Feat/matmul/faster (#479)

This commit is contained in:
Nathaniel Simard 2023-07-07 12:00:37 -04:00 committed by GitHub
parent 261aa952c0
commit 513b9281c2
22 changed files with 1278 additions and 496 deletions

View File

@ -1,11 +1,13 @@
use std::marker::PhantomData;
use burn_tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_wgpu::{
benchmark::Benchmark,
kernel::{matmul_mem_coalescing_default, matmul_naive_default, matmul_tiling_2d_default},
kernel::matmul::{
continuous, continuous_vectorized, matmul_mem_coalescing_default, matmul_naive_default,
tile, tile_vectorized,
},
run_benchmark, GraphicsApi, WgpuBackend, WgpuDevice,
};
use std::marker::PhantomData;
trait MatmulFunction<B: Backend, const D: usize> {
fn run(lhs: Tensor<B, D>, rhs: Tensor<B, D>) -> Tensor<B, D>;
@ -37,6 +39,10 @@ where
)
}
fn num_samples(&self) -> usize {
10
}
fn execute(&self, (lhs, rhs): Self::Args) {
for _ in 0..self.num_repeats {
F::run(lhs.clone(), rhs.clone());
@ -51,71 +57,69 @@ where
}
}
struct Tiling2DMatmul;
macro_rules! benchmark {
($name:ident, $func:expr) => {
struct $name;
impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D>
for Tiling2DMatmul
{
fn run(
lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
) -> Tensor<WgpuBackend<G, f32, i32>, D> {
Tensor::from_primitive(matmul_tiling_2d_default(
lhs.into_primitive(),
rhs.into_primitive(),
))
}
impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D> for $name {
fn run(
lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
) -> Tensor<WgpuBackend<G, f32, i32>, D> {
Tensor::from_primitive($func(lhs.into_primitive(), rhs.into_primitive()))
}
}
};
}
struct NaiveMatmul;
impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D> for NaiveMatmul {
fn run(
lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
) -> Tensor<WgpuBackend<G, f32, i32>, D> {
Tensor::from_primitive(matmul_naive_default(
lhs.into_primitive(),
rhs.into_primitive(),
))
}
}
struct MemCoalescingMatmul;
impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D>
for MemCoalescingMatmul
{
fn run(
lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
) -> Tensor<WgpuBackend<G, f32, i32>, D> {
Tensor::from_primitive(matmul_mem_coalescing_default(
lhs.into_primitive(),
rhs.into_primitive(),
))
}
}
benchmark!(NaiveMatmul, matmul_naive_default);
benchmark!(MemCoalescingMatmul, matmul_mem_coalescing_default);
benchmark!(
Tiling2DMatmulContinuous,
continuous::matmul_tiling_2d_default
);
benchmark!(Tiling2DMatmulTile, tile::matmul_tiling_2d_default);
benchmark!(
Tiling2DMatmulTileVectorized,
tile_vectorized::matmul_tiling_2d_default
);
benchmark!(
Tiling2DMatmulContinuousVectorized,
continuous_vectorized::matmul_tiling_2d_default
);
fn main() {
let batch_size = 32;
let matrix_size = 128;
run_benchmark!(MatmulBenchmark::<NaiveMatmul, 3> {
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
num_repeats: 10,
matmul: PhantomData::default()
});
let num_repeats = 3;
let batch_size = 3;
let matrix_size = 1000;
run_benchmark!(MatmulBenchmark::<MemCoalescingMatmul, 3> {
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
num_repeats: 10,
num_repeats,
matmul: PhantomData::default()
});
run_benchmark!(MatmulBenchmark::<Tiling2DMatmul, 3> {
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulContinuous, 3> {
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
num_repeats: 10,
num_repeats,
matmul: PhantomData::default()
});
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulContinuousVectorized, 3> {
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
num_repeats,
matmul: PhantomData::default()
});
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulTile, 3> {
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
num_repeats,
matmul: PhantomData::default()
});
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulTileVectorized, 3> {
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
num_repeats,
matmul: PhantomData::default()
});
}

View File

@ -1,3 +1,4 @@
use super::utils::shape_out;
use crate::{
context::WorkGroup,
element::WgpuElement,
@ -5,11 +6,10 @@ use crate::{
kernel_wgsl,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
kernel_wgsl!(
MatmulMemCoalescingRaw,
"../../template/matmul_mem_coalescing.wgsl"
"../../template/matmul/mem_coalescing.wgsl"
);
struct MatmulMemCoalescing<const WORKGROUP_SIZE_X: usize, const WORKGROUP_SIZE_Y: usize>;
@ -44,21 +44,9 @@ pub fn matmul_mem_coalescing<
) -> WgpuTensor<E, D> {
lhs.assert_is_on_same_device(&rhs);
let mut shape_out = [0; D];
lhs.shape
.dims
.iter()
.zip(rhs.shape.dims.iter())
.enumerate()
.for_each(|(index, (dim_lhs, dim_rhs))| {
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
});
let shape_out = shape_out(&lhs, &rhs);
let num_rows = lhs.shape.dims[D - 2];
let num_cols = rhs.shape.dims[D - 1];
shape_out[D - 2] = num_rows;
shape_out[D - 1] = num_cols;
let shape_out = Shape::new(shape_out);
let buffer = lhs
.context
@ -103,10 +91,7 @@ pub fn matmul_mem_coalescing<
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::TestTensor;
pub type ReferenceTensor<const D: usize> =
burn_tensor::Tensor<burn_ndarray::NdArrayBackend<f32>, D>;
use crate::kernel::matmul::utils::tests::same_as_reference;
#[test]
pub fn test_matmul_mem_coalescing_straightforward() {
@ -167,25 +152,4 @@ mod tests {
let shape_rhs = [batch_1, batch_2, k, n];
same_as_reference(func, shape_lhs, shape_rhs);
}
fn same_as_reference<F, const D: usize, S>(func: F, shape_lhs: S, shape_rhs: S)
where
F: Fn(WgpuTensor<f32, D>, WgpuTensor<f32, D>) -> WgpuTensor<f32, D>,
S: Into<Shape<D>>,
{
let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
let x_wgpu = TestTensor::from_data(x.to_data());
let y_wgpu = TestTensor::from_data(y.to_data());
let z_reference = x.matmul(y);
let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive());
let z = TestTensor::from_primitive(z);
println!("{z}");
z_reference.into_data().assert_approx_eq(&z.into_data(), 3);
}
}

View File

@ -1,3 +1,5 @@
pub(crate) mod utils;
mod mem_coalescing;
mod naive;
mod tiling2d;

View File

@ -1,3 +1,4 @@
use super::utils::shape_out;
use crate::{
context::WorkGroup,
element::WgpuElement,
@ -5,9 +6,8 @@ use crate::{
kernel_wgsl,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul_naive.wgsl");
kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul/naive.wgsl");
struct MatmulNaive<const WORKGROUP_SIZE_X: usize, const WORKGROUP_SIZE_Y: usize>;
@ -41,21 +41,10 @@ pub fn matmul_naive<
) -> WgpuTensor<E, D> {
lhs.assert_is_on_same_device(&rhs);
let mut shape_out = [0; D];
lhs.shape
.dims
.iter()
.zip(rhs.shape.dims.iter())
.enumerate()
.for_each(|(index, (dim_lhs, dim_rhs))| {
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
});
let shape_out = shape_out(&lhs, &rhs);
let num_rows = lhs.shape.dims[D - 2];
let num_cols = rhs.shape.dims[D - 1];
shape_out[D - 2] = num_rows;
shape_out[D - 1] = num_cols;
let shape_out = Shape::new(shape_out);
let buffer = lhs
.context
@ -100,10 +89,7 @@ pub fn matmul_naive<
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::TestTensor;
pub type ReferenceTensor<const D: usize> =
burn_tensor::Tensor<burn_ndarray::NdArrayBackend<f32>, D>;
use crate::kernel::matmul::utils::tests::same_as_reference;
#[test]
pub fn test_matmul_naive_straightforward() {
@ -162,25 +148,4 @@ mod tests {
let shape_rhs = [batch_1, batch_2, k, n];
same_as_reference(func, shape_lhs, shape_rhs);
}
fn same_as_reference<F, const D: usize, S>(func: F, shape_lhs: S, shape_rhs: S)
where
F: Fn(WgpuTensor<f32, D>, WgpuTensor<f32, D>) -> WgpuTensor<f32, D>,
S: Into<Shape<D>>,
{
let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
let x_wgpu = TestTensor::from_data(x.to_data());
let y_wgpu = TestTensor::from_data(y.to_data());
let z_reference = x.matmul(y);
let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive());
let z = TestTensor::from_primitive(z);
println!("{z}");
z_reference.into_data().assert_approx_eq(&z.into_data(), 3);
}
}

View File

@ -1,330 +0,0 @@
use std::cmp::{max, min};
use crate::{
context::WorkGroup,
element::WgpuElement,
kernel::{build_info, KernelSettings, SourceTemplate, StaticKernel},
kernel_wgsl,
tensor::WgpuTensor,
};
use burn_tensor::Shape;
const MAX_SHARED_MEMORY_SIZE: usize = 8192;
kernel_wgsl!(
MatmulTiling2DRaw,
"../../template/matmul_blocktiling_2d.wgsl"
);
struct MatmulTiling2D<
const B_M: usize,
const B_N: usize,
const B_K: usize,
const T_M: usize,
const T_N: usize,
const WORKGROUP_SIZE_X: usize,
const WORKGROUP_SIZE_Y: usize,
>;
impl<
const B_M: usize,
const B_N: usize,
const B_K: usize,
const T_M: usize,
const T_N: usize,
const WORKGROUP_SIZE_X: usize,
const WORKGROUP_SIZE_Y: usize,
> StaticKernel for MatmulTiling2D<B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>
{
fn source_template() -> SourceTemplate {
MatmulTiling2DRaw::source_template()
.register("b_m", B_M.to_string())
.register("b_n", B_N.to_string())
.register("b_k", B_K.to_string())
.register("bm_x_bk", (B_M * B_K).to_string())
.register("bk_x_bn", (B_K * B_N).to_string())
.register("t_m", T_M.to_string())
.register("t_n", T_N.to_string())
.register("tm_x_tn", (T_M * T_N).to_string())
}
}
/// Matrix multiplication using tiling 2D algorithm with default parameters
pub fn matmul_tiling_2d_default<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
// Suppose a matmul of m1 of size [M, K] with m2 of size [K, N]
// Block size along dim M
const B_M: usize = 128;
// // Block size along dim N
const B_N: usize = 128;
// // Block size along dim K
const B_K: usize = 8;
// // Tiling size along dim M
const T_M: usize = 8;
// // Tiling size along dim N
const T_N: usize = 8;
// WORKGROUP_SIZE_X = ceil(B_M / T_M)
const WORKGROUP_SIZE_X: usize = 16;
// WORKGROUP_SIZE_Y = ceil(B_N / T_N)
const WORKGROUP_SIZE_Y: usize = 16;
matmul_tiling_2d::<E, D, B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>(lhs, rhs)
}
/// Matrix multiplication using tiling 2D algorithm with custom parameters
pub fn matmul_tiling_2d<
E: WgpuElement,
const D: usize,
const B_M: usize,
const B_N: usize,
const B_K: usize,
const T_M: usize,
const T_N: usize,
const WORKGROUP_SIZE_X: usize,
const WORKGROUP_SIZE_Y: usize,
>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
assert!(B_K <= min(B_M, B_N), "B_K must be smaller than both B_M and B_M, otherwise there won't be enough threads to fill shared memory. ");
assert!(B_K * max(B_M, B_N) <= MAX_SHARED_MEMORY_SIZE, "B_K x B_M and B_K x B_N must be smaller or equal than 8192, otherwise shared memory limit will be busted. ");
assert!(
WORKGROUP_SIZE_X == f32::ceil(B_M as f32 / T_M as f32) as usize,
"Workgroup size x must equal ceil(B_M / T_M)"
);
assert!(
WORKGROUP_SIZE_Y == f32::ceil(B_N as f32 / T_N as f32) as usize,
"Workgroup size y must equal ceil(B_N / T_N)"
);
lhs.assert_is_on_same_device(&rhs);
let mut shape_out = [0; D];
lhs.shape
.dims
.iter()
.zip(rhs.shape.dims.iter())
.enumerate()
.for_each(|(index, (dim_lhs, dim_rhs))| {
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
});
let num_rows = lhs.shape.dims[D - 2];
let num_cols = rhs.shape.dims[D - 1];
shape_out[D - 2] = num_rows;
shape_out[D - 1] = num_cols;
let shape_out = Shape::new(shape_out);
let buffer = lhs
.context
.create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
// set number of workgroups
let blocks_needed_in_x = f32::ceil(num_rows as f32 / B_M as f32) as u32;
let blocks_needed_in_y = f32::ceil(num_cols as f32 / B_N as f32) as u32;
let kernel = lhs.context.compile_static::<KernelSettings<
MatmulTiling2D<B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>,
E,
i32,
WORKGROUP_SIZE_X,
WORKGROUP_SIZE_Y,
1,
>>();
let info = build_info(&[&lhs, &rhs, &output]);
let info_buffers = lhs
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
let mut num_iter = 1;
for i in 0..D - 2 {
num_iter *= output.shape.dims[i];
}
let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32);
lhs.context.execute(
workgroup,
kernel,
&[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
);
output
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::TestTensor;
pub type ReferenceTensor<const D: usize> =
burn_tensor::Tensor<burn_ndarray::NdArrayBackend<f32>, D>;
#[test]
pub fn test_matmul_tiling_2d_shapes_smaller_than_blocks() {
test_with_params::<128, 128, 16, 8, 8, 16, 16>(8, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_m_not_equals_n() {
test_with_params::<16, 16, 8, 8, 8, 2, 2>(8, 8, 3, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_k_smaller_than_m_n() {
test_with_params::<16, 16, 8, 8, 8, 2, 2>(8, 3, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_k_larger_than_m_n() {
test_with_params::<16, 16, 8, 8, 8, 2, 2>(8, 48, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_t_divides_b_unevenly() {
test_with_params::<128, 128, 8, 7, 11, 19, 12>(8, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_small_parameters() {
test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_bm_not_equals_bn() {
test_with_params::<32, 128, 8, 8, 8, 4, 16>(8, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_multibatch_1_dim() {
test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 8, 8, 3, 1);
}
#[test]
pub fn test_matmul_tiling_2d_multibatch_2_dims() {
test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 8, 8, 3, 4);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_memory_busted_should_panic() {
test_with_params::<128, 128, 128, 8, 8, 16, 16>(8, 8, 8, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_bk_larger_than_bm_should_panic() {
test_with_params::<64, 64, 128, 8, 8, 8, 8>(8, 8, 8, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_workgroup_x_wrong_should_panic() {
test_with_params::<128, 128, 16, 8, 8, 16, 8>(8, 8, 8, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_workgroup_y_wrong_should_panic() {
test_with_params::<128, 128, 16, 8, 8, 8, 7>(8, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_multiple_blocks() {
test_with_params::<16, 16, 8, 8, 8, 2, 2>(32, 32, 32, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_k_bigger_than_bk() {
test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 10, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_blocks_divide_shapes_unevenly() {
test_with_params::<16, 16, 8, 8, 8, 2, 2>(31, 23, 17, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_large_parameters() {
test_with_params::<256, 256, 16, 16, 16, 16, 16>(40, 40, 40, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_shapes_slightly_larger_than_blocks() {
test_with_params::<32, 32, 8, 8, 8, 4, 4>(40, 40, 30, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_shapes_way_larger_than_blocks() {
test_with_params::<16, 16, 8, 8, 8, 2, 2>(50, 50, 50, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_tm_larger_than_bm() {
test_with_params::<2, 2, 2, 3, 2, 1, 1>(5, 5, 5, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_tn_larger_than_bn() {
test_with_params::<2, 2, 2, 2, 3, 1, 1>(5, 5, 5, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_uneven_parameters() {
test_with_params::<17, 15, 11, 13, 7, 2, 3>(24, 24, 24, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_uneven_parameters_2() {
test_with_params::<11, 14, 10, 7, 17, 2, 1>(10, 24, 17, 1, 1);
}
fn test_with_params<
const B_M: usize,
const B_N: usize,
const B_K: usize,
const T_M: usize,
const T_N: usize,
const WORKGROUP_SIZE_X: usize,
const WORKGROUP_SIZE_Y: usize,
>(
m: usize,
k: usize,
n: usize,
batch_1: usize,
batch_2: usize,
) {
let func = |lhs, rhs| {
matmul_tiling_2d::<f32, 4, B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>(
lhs, rhs,
)
};
let shape_lhs = [batch_1, batch_2, m, k];
let shape_rhs = [batch_1, batch_2, k, n];
same_as_reference(func, shape_lhs, shape_rhs);
}
fn same_as_reference<F, const D: usize, S>(func: F, shape_lhs: S, shape_rhs: S)
where
F: Fn(WgpuTensor<f32, D>, WgpuTensor<f32, D>) -> WgpuTensor<f32, D>,
S: Into<Shape<D>>,
{
let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
let x_wgpu = TestTensor::from_data(x.to_data());
let y_wgpu = TestTensor::from_data(y.to_data());
let z_reference = x.matmul(y);
let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive());
let z = TestTensor::from_primitive(z);
println!("{z}");
z_reference.into_data().assert_approx_eq(&z.into_data(), 3);
}
}

View File

@ -0,0 +1,415 @@
use crate::{
context::{Context, WorkGroup},
element::WgpuElement,
kernel::{build_info, matmul::utils::shape_out, SourceTemplate},
tensor::WgpuTensor,
};
use burn_tensor::Shape;
use std::{
cmp::{max, min},
sync::Arc,
};
use wgpu::ComputePipeline;
use super::padding::{crop, pad_round};
const MAX_SHARED_MEMORY_SIZE: usize = 8192;
pub(super) fn empty_from_context<E: WgpuElement, const D: usize>(
context: Arc<Context>,
shape: &Shape<D>,
) -> WgpuTensor<E, D> {
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
WgpuTensor::new(context, shape.clone(), buffer)
}
/// Generate a tiling 2D matmul kernel (struct, launch functions and tests) from a WGSL template file.
#[macro_export(local_inner_macros)]
macro_rules! matmul_tile_2d {
(
$struct:ident,
$file:expr
) => {
matmul_tile_2d!(
$struct,
$file,
B_M 64,
B_N 64,
B_K 32,
T_M 4,
T_N 4
);
};
(
$struct:ident,
$file:expr,
B_M $bm:expr,
B_N $bn:expr,
B_K $bk:expr,
T_M $tm:expr,
T_N $tn:expr
) => {
struct $struct<
const B_M: usize,
const B_N: usize,
const B_K: usize,
const T_M: usize,
const T_N: usize,
const WORKGROUP_SIZE_X: usize,
const WORKGROUP_SIZE_Y: usize,
>;
impl<
const B_M: usize,
const B_N: usize,
const B_K: usize,
const T_M: usize,
const T_N: usize,
const WORKGROUP_SIZE_X: usize,
const WORKGROUP_SIZE_Y: usize,
> StaticKernel
for $struct<B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>
{
fn source_template() -> SourceTemplate {
kernel_wgsl!(Raw, $file);
register_template::<B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>(
Raw::source_template(),
)
}
}
/// Matrix multiplication using the tiling 2D algorithm with default parameters
pub fn matmul_tiling_2d_default<E: WgpuElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
// Consider a matmul between m1 of shape [M, K] and m2 of shape [K, N].
// Block size along dim M
const B_M: usize = $bm;
// Block size along dim N
const B_N: usize = $bn;
// Block size along dim K
const B_K: usize = $bk;
// Tiling size along dim M
const T_M: usize = $tm;
// Tiling size along dim N
const T_N: usize = $tn;
// WORKGROUP_SIZE_X = ceil(B_M / T_M)
const WORKGROUP_SIZE_X: usize = B_M / T_M;
// WORKGROUP_SIZE_Y = ceil(B_N / T_N)
const WORKGROUP_SIZE_Y: usize = B_N / T_N;
matmul_tiling_2d::<E, D, B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>(
lhs, rhs,
)
}
/// Matrix multiplication using the tiling 2D algorithm with custom parameters
pub fn matmul_tiling_2d<
E: WgpuElement,
const D: usize,
const B_M: usize,
const B_N: usize,
const B_K: usize,
const T_M: usize,
const T_N: usize,
const WORKGROUP_SIZE_X: usize,
const WORKGROUP_SIZE_Y: usize,
>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
let kernel = lhs.context.compile_static::<KernelSettings<
$struct<B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>,
E,
i32,
WORKGROUP_SIZE_X,
WORKGROUP_SIZE_Y,
1,
>>();
matmul_tiling_2d_launch::<
E,
D,
B_M,
B_N,
B_K,
T_M,
T_N,
WORKGROUP_SIZE_X,
WORKGROUP_SIZE_Y,
>(lhs, rhs, kernel)
}
#[cfg(test)]
mod tests {
use super::*;
use $crate::kernel::matmul::utils::tests::same_as_reference;
#[test]
pub fn test_matmul_tiling_2d_large_blocks() {
test_with_params::<128, 128, 8, 4, 4, 32, 32>(8, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_shapes_smaller_than_blocks() {
test_with_params::<64, 64, 8, 4, 4, 16, 16>(8, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_m_not_equals_n() {
test_with_params::<16, 16, 8, 2, 2, 8, 8>(16, 8, 16, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_k_smaller_than_m_n() {
test_with_params::<16, 16, 4, 2, 2, 8, 8>(16, 4, 16, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_k_larger_than_m_n() {
test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 48, 8, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_t_divides_b_unevenly_should_panic() {
test_with_params::<128, 128, 8, 7, 11, 19, 12>(8, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_bm_not_equals_bn() {
test_with_params::<8, 16, 8, 2, 4, 4, 4>(8, 8, 16, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_multibatch_1_dim() {
test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 8, 8, 3, 1);
}
#[test]
pub fn test_matmul_tiling_2d_multibatch_2_dims() {
test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 8, 8, 3, 4);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_memory_busted_should_panic() {
test_with_params::<128, 128, 128, 8, 8, 16, 16>(8, 8, 8, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_bk_larger_than_bm_should_panic() {
test_with_params::<64, 64, 128, 8, 8, 8, 8>(8, 8, 8, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_workgroup_x_wrong_should_panic() {
test_with_params::<128, 128, 16, 8, 8, 16, 8>(8, 8, 8, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_workgroup_y_wrong_should_panic() {
test_with_params::<128, 128, 16, 8, 8, 8, 7>(8, 8, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_multiple_blocks() {
test_with_params::<16, 16, 8, 2, 2, 8, 8>(32, 32, 32, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_k_bigger_than_bk() {
test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 16, 8, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_blocks_divide_shapes_unevenly() {
test_with_params::<16, 16, 8, 2, 2, 8, 8>(31, 23, 17, 1, 1);
}
#[test]
pub fn test_matmul_tiling_2d_shapes_way_larger_than_blocks() {
test_with_params::<16, 16, 8, 2, 2, 8, 8>(48, 48, 48, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_tm_larger_than_bm_should_panic() {
test_with_params::<2, 2, 2, 3, 2, 1, 1>(5, 5, 5, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_tn_larger_than_bn_should_panic() {
test_with_params::<2, 2, 2, 2, 3, 1, 1>(5, 5, 5, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_uneven_parameters_should_panic() {
test_with_params::<17, 15, 11, 13, 7, 2, 3>(24, 24, 24, 1, 1);
}
#[test]
#[should_panic]
pub fn test_matmul_tiling_2d_uneven_parameters_2_should_panic() {
test_with_params::<11, 14, 10, 7, 17, 2, 1>(10, 24, 17, 1, 1);
}
fn test_with_params<
const B_M: usize,
const B_N: usize,
const B_K: usize,
const T_M: usize,
const T_N: usize,
const WORKGROUP_SIZE_X: usize,
const WORKGROUP_SIZE_Y: usize,
>(
m: usize,
k: usize,
n: usize,
batch_1: usize,
batch_2: usize,
) {
let func = |lhs, rhs| {
matmul_tiling_2d::<f32, 4, B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>(
lhs, rhs,
)
};
let shape_lhs = [batch_1, batch_2, m, k];
let shape_rhs = [batch_1, batch_2, k, n];
same_as_reference(func, shape_lhs, shape_rhs);
}
}
};
}
pub(super) fn register_template<
const B_M: usize,
const B_N: usize,
const B_K: usize,
const T_M: usize,
const T_N: usize,
const WORKGROUP_SIZE_X: usize,
const WORKGROUP_SIZE_Y: usize,
>(
template: SourceTemplate,
) -> SourceTemplate {
template
.register("b_m", B_M.to_string())
.register("b_n", B_N.to_string())
.register("b_k", B_K.to_string())
.register("bm_x_bk", (B_M * B_K).to_string())
.register("bk_x_bn", (B_K * B_N).to_string())
.register("t_m", T_M.to_string())
.register("t_n", T_N.to_string())
.register("tm_x_tn", (T_M * T_N).to_string())
}
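As a worked example (not part of the diff), the registered values resolve as follows for the defaults declared in matmul_tile_2d! above:
// With B_M = 64, B_N = 64, B_K = 32, T_M = 4, T_N = 4:
// bm_x_bk = 64 * 32 = 2048 -> size of the lhs shared-memory tile
// bk_x_bn = 32 * 64 = 2048 -> size of the rhs shared-memory tile
// tm_x_tn = 4 * 4 = 16 -> results accumulated per thread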
#[allow(clippy::too_many_arguments)]
pub(super) fn matmul_parameter_assertions<E: WgpuElement, const D: usize>(
b_m: usize,
b_n: usize,
b_k: usize,
t_m: usize,
t_n: usize,
workgroup_size_x: usize,
workgroup_size_y: usize,
lhs: &WgpuTensor<E, D>,
rhs: &WgpuTensor<E, D>,
) {
assert!(b_k <= min(b_m, b_n), "B_K must be no larger than B_M and B_N, otherwise there won't be enough threads to fill shared memory.");
assert!(b_k * max(b_m, b_n) <= MAX_SHARED_MEMORY_SIZE, "B_K x B_M and B_K x B_N must not exceed 8192, otherwise the shared memory limit will be exceeded.");
assert!(
b_m % t_m == 0 && b_n % t_n == 0,
"T_M must divide B_M and T_N must divide B_N in this version"
);
assert!(
workgroup_size_x == b_m / t_m,
"Workgroup size x must equal B_M / T_M"
);
assert!(
workgroup_size_y == b_n / t_n,
"Workgroup size y must equal B_N / T_N"
);
lhs.assert_is_on_same_device(rhs);
}
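As a quick sanity check (worked arithmetic, not from the diff), the macro defaults above satisfy these assertions:
// B_K = 32 <= min(B_M, B_N) = 64
// B_K * max(B_M, B_N) = 32 * 64 = 2048 <= MAX_SHARED_MEMORY_SIZE = 8192
// B_M % T_M = 64 % 4 = 0 and B_N % T_N = 64 % 4 = 0
// WORKGROUP_SIZE_X = 64 / 4 = 16 and WORKGROUP_SIZE_Y = 64 / 4 = 16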
pub(super) fn make_workgroup<const D: usize>(
output_shape: Shape<D>,
b_m: usize,
b_n: usize,
) -> WorkGroup {
let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / b_m as f32) as u32;
let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / b_n as f32) as u32;
let mut num_blocks_z = 1;
for i in 0..D - 2 {
num_blocks_z *= output_shape.dims[i];
}
WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32)
}
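A minimal sketch of the dispatch arithmetic, assuming an already padded output shape (the shape and block sizes below are illustrative):
// For a padded output of shape [2, 256, 192] with B_M = B_N = 64:
// ceil(256 / 64) = 4 blocks along rows, ceil(192 / 64) = 3 along columns,
// and 2 blocks along z, one per batch entry.
let workgroup = make_workgroup(Shape::new([2, 256, 192]), 64, 64);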
pub(super) fn make_info_buffers<E: WgpuElement, const D: usize>(
lhs: &WgpuTensor<E, D>,
rhs: &WgpuTensor<E, D>,
output: &WgpuTensor<E, D>,
) -> Arc<wgpu::Buffer> {
let info = build_info(&[lhs, rhs, output]);
rhs.context
.create_buffer_with_data(bytemuck::cast_slice(&info))
}
pub(super) fn matmul_tiling_2d_launch<
E: WgpuElement,
const D: usize,
const B_M: usize,
const B_N: usize,
const B_K: usize,
const T_M: usize,
const T_N: usize,
const WORKGROUP_SIZE_X: usize,
const WORKGROUP_SIZE_Y: usize,
>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,
kernel: Arc<ComputePipeline>,
) -> WgpuTensor<E, D> {
matmul_parameter_assertions::<E, D>(
B_M,
B_N,
B_K,
T_M,
T_N,
WORKGROUP_SIZE_X,
WORKGROUP_SIZE_Y,
&lhs,
&rhs,
);
let final_output_shape = shape_out(&lhs, &rhs);
let lhs = pad_round(lhs, B_M, B_K);
let rhs = pad_round(rhs, B_K, B_N);
let rounded_output_shape = shape_out(&lhs, &rhs);
let output = empty_from_context::<E, D>(rhs.context.clone(), &rounded_output_shape);
let workgroup = make_workgroup(rounded_output_shape, B_M, B_N);
let info_buffers = make_info_buffers(&lhs, &rhs, &output);
lhs.context.execute(
workgroup,
kernel,
&[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
);
crop(output, final_output_shape)
}
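A rough trace with the default block sizes (worked example, not from the diff) for a [1000, 1000] x [1000, 1000] matmul:
// lhs rows are padded to a multiple of B_M = 64, cols to a multiple of B_K = 32 -> [1024, 1024]
// rhs rows are padded to a multiple of B_K = 32, cols to a multiple of B_N = 64 -> [1024, 1024]
// the kernel computes a [1024, 1024] output, which crop() trims back to [1000, 1000]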

View File

@ -0,0 +1,12 @@
use super::base::{matmul_tiling_2d_launch, register_template};
use crate::{
element::WgpuElement,
kernel::{KernelSettings, SourceTemplate, StaticKernel},
matmul_tile_2d,
tensor::WgpuTensor,
};
matmul_tile_2d!(
MatmulTiling2DContinuous,
"../../../template/matmul/blocktiling_2d/continuous.wgsl"
);

View File

@ -0,0 +1,12 @@
use super::base::{matmul_tiling_2d_launch, register_template};
use crate::{
element::WgpuElement,
kernel::{KernelSettings, SourceTemplate, StaticKernel},
matmul_tile_2d,
tensor::WgpuTensor,
};
matmul_tile_2d!(
MatmulTiling2DContinuousVectorized,
"../../../template/matmul/blocktiling_2d/continuous_vectorized.wgsl"
);

View File

@ -0,0 +1,13 @@
mod base;
mod padding;
/// Loading is done in a continuous manner
pub mod continuous;
/// Loading is done in a continuous manner. lhs is transposed
pub mod continuous_vectorized;
/// Loading is done in a tile manner
pub mod tile;
/// Loading is done in a tile manner. lhs is transposed
pub mod tile_vectorized;
pub use tile_vectorized::*;

View File

@ -0,0 +1,241 @@
use std::ops::Range;
use burn_tensor::Shape;
use crate::{
element::WgpuElement,
kernel::{slice, slice_assign},
tensor::WgpuTensor,
};
use super::base::empty_from_context;
/// Pads a tensor with zeros so that its number of rows and columns
/// are divisible by the given divisors.
/// For instance, a tensor of shape [1000, 1000] with divisors 64 and 64
/// will be padded to [1024, 1024], the extra 24 rows and 24 columns being zeros.
pub(super) fn pad_round<E: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
row_divisor: usize,
col_divisor: usize,
) -> WgpuTensor<E, D> {
let row_modulo = tensor.shape.dims[D - 2] % row_divisor;
let col_modulo = tensor.shape.dims[D - 1] % col_divisor;
if row_modulo == 0 && col_modulo == 0 {
return tensor;
}
let mut padded_shape = Vec::with_capacity(D);
for i in 0..D - 2 {
padded_shape.push(tensor.shape.dims[i]);
}
padded_shape.push(tensor.shape.dims[D - 2] - row_modulo + row_divisor);
padded_shape.push(tensor.shape.dims[D - 1] - col_modulo + col_divisor);
padding::<E, D>(tensor, padded_shape.into())
}
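The rounding is plain next-multiple arithmetic; a small standalone sketch of the same computation (round_up is a made-up helper name for illustration):
// Next multiple of `divisor` at or above `dim`; exact multiples are left unchanged.
fn round_up(dim: usize, divisor: usize) -> usize {
let modulo = dim % divisor;
if modulo == 0 {
dim
} else {
dim - modulo + divisor
}
}
// round_up(1000, 64) == 1024, matching the [1000, 1000] -> [1024, 1024] example above.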
/// Pads tensor by adding zeros when padded dim is larger than tensor dim
fn padding<E: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
padded_shape: Shape<D>,
) -> WgpuTensor<E, D> {
let ranges = padded_shape
.dims
.iter()
.map(|dim| 0..*dim)
.collect::<Vec<Range<usize>>>()
.try_into()
.unwrap();
slice_assign::<E, D, D>(
empty_from_context(tensor.context.clone(), &padded_shape),
ranges,
tensor,
)
}
/// Crops tensor by deleting values when cropped dim is smaller than tensor dim
pub(super) fn crop<E: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
cropped_shape: Shape<D>,
) -> WgpuTensor<E, D> {
let ranges = cropped_shape
.dims
.iter()
.map(|dim| 0..*dim)
.collect::<Vec<Range<usize>>>()
.try_into()
.unwrap();
slice::<E, D, D>(tensor, ranges)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tests::TestTensor;
#[test]
fn padding_already_round_should_have_same_shape() {
let row = 10;
let row_divisor = 5;
let col = 12;
let col_divisor = 3;
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
let expected_shape = [row, col].into();
let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor);
assert!(padded.shape == expected_shape);
}
#[test]
fn padding_already_round_should_have_same_values() {
let row = 10;
let row_divisor = 5;
let col = 12;
let col_divisor = 3;
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor);
let padded = TestTensor::from_primitive(padded);
padded.into_data().assert_approx_eq(&tensor.into_data(), 3);
}
#[test]
fn padding_not_round_should_have_rounded_shape() {
let row = 10;
let row_divisor = 6;
let col = 12;
let col_divisor = 5;
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
let expected_shape = [12, 15].into();
let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor);
assert!(padded.shape == expected_shape);
}
#[test]
fn padding_not_round_should_have_same_values() {
let row = 10;
let row_divisor = 6;
let col = 12;
let col_divisor = 5;
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor);
let padded = TestTensor::from_primitive(padded).to_data();
let tensor = tensor.into_data();
for i in 0..row {
for j in 0..col {
assert!(padded.value[i * 15 + j] == tensor.value[i * col + j]);
}
}
}
#[test]
fn padding_not_round_should_have_zero_padding() {
let row = 10;
let row_divisor = 6;
let col = 12;
let col_divisor = 5;
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor);
let padded = TestTensor::from_primitive(padded).to_data();
// check right of matrix
for i in 0..row {
for j in col..15 {
assert!(padded.value[i * 15 + j] == 0.0);
}
}
// check below matrix, including bottom right
for i in row..12 {
for j in 0..15 {
assert!(padded.value[i * 15 + j] == 0.0);
}
}
}
#[test]
fn padding_works_with_batch() {
let row = 10;
let row_divisor = 4;
let col = 12;
let col_divisor = 5;
let tensor = TestTensor::random([2, 3, row, col], burn_tensor::Distribution::Default);
let expected_shape = [2, 3, 12, 15].into();
let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor);
assert!(padded.shape == expected_shape);
}
#[test]
fn crop_same_shape_should_be_unchanged_shape() {
let row = 10;
let col = 12;
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
let expected_shape = [row, col].into();
let unpadded = crop(tensor.into_primitive(), [row, col].into());
assert!(unpadded.shape == expected_shape);
}
#[test]
fn crop_same_shape_should_have_unchanged_values() {
let row = 10;
let col = 12;
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
let unpadded = crop(tensor.clone().into_primitive(), [row, col].into());
let unpadded = TestTensor::from_primitive(unpadded).to_data();
let tensor = tensor.into_data();
for i in 0..row {
for j in 0..col {
assert!(unpadded.value[i * col + j] == tensor.value[i * col + j]);
}
}
}
#[test]
fn crop_should_decrease_shape() {
let row = 10;
let keep_rows = 8;
let col = 12;
let keep_cols = 10;
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
let expected_shape = [keep_rows, keep_cols].into();
let unpadded = crop(tensor.into_primitive(), [keep_rows, keep_cols].into());
assert!(unpadded.shape == expected_shape);
}
#[test]
fn crop_should_keep_same_values() {
let row = 4;
let keep_rows = 3;
let col = 4;
let keep_cols = 3;
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
let unpadded = crop(
tensor.clone().into_primitive(),
[keep_rows, keep_cols].into(),
);
let unpadded = TestTensor::from_primitive(unpadded).to_data();
let tensor = tensor.into_data();
println!("{:?}\n {:?}", unpadded, tensor);
for i in 0..keep_rows {
for j in 0..keep_cols {
assert!(unpadded.value[i * keep_cols + j] == tensor.value[i * col + j]);
}
}
}
}

View File

@ -0,0 +1,13 @@
use crate::{
element::WgpuElement,
kernel::{KernelSettings, SourceTemplate, StaticKernel},
matmul_tile_2d,
tensor::WgpuTensor,
};
use super::base::{matmul_tiling_2d_launch, register_template};
matmul_tile_2d!(
MatmulTiling2DTile,
"../../../template/matmul/blocktiling_2d/tile.wgsl"
);

View File

@ -0,0 +1,12 @@
use super::base::{matmul_tiling_2d_launch, register_template};
use crate::{
element::WgpuElement,
kernel::{KernelSettings, SourceTemplate, StaticKernel},
matmul_tile_2d,
tensor::WgpuTensor,
};
matmul_tile_2d!(
MatmulTiling2DTileVectorized,
"../../../template/matmul/blocktiling_2d/tile_vectorized.wgsl"
);

View File

@ -0,0 +1,48 @@
use crate::{element::WgpuElement, tensor::WgpuTensor};
use burn_tensor::Shape;
pub(crate) fn shape_out<E: WgpuElement, const D: usize>(
lhs: &WgpuTensor<E, D>,
rhs: &WgpuTensor<E, D>,
) -> Shape<D> {
let mut shape_out = [0; D];
lhs.shape
.dims
.iter()
.zip(rhs.shape.dims.iter())
.enumerate()
.for_each(|(index, (dim_lhs, dim_rhs))| {
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
});
shape_out[D - 2] = lhs.shape.dims[D - 2];
shape_out[D - 1] = rhs.shape.dims[D - 1];
Shape::new(shape_out)
}
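For illustration (shapes made up): batch dimensions are broadcast by taking the maximum, while the last two dimensions come from the lhs rows and the rhs columns.
// lhs shape [2, 1, 64, 32] and rhs shape [1, 5, 32, 128]
// -> batch dims: max(2, 1) = 2 and max(1, 5) = 5
// -> output shape: [2, 5, 64, 128]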
#[cfg(test)]
pub(crate) mod tests {
use crate::tensor::WgpuTensor;
use crate::tests::{ReferenceTensor, TestTensor};
use burn_tensor::Shape;
pub(crate) fn same_as_reference<F, const D: usize, S>(func: F, shape_lhs: S, shape_rhs: S)
where
F: Fn(WgpuTensor<f32, D>, WgpuTensor<f32, D>) -> WgpuTensor<f32, D>,
S: Into<Shape<D>>,
{
let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
let x_wgpu = TestTensor::from_data(x.to_data());
let y_wgpu = TestTensor::from_data(y.to_data());
let z_reference = x.matmul(y);
let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive());
let z = TestTensor::from_primitive(z);
std::println!("{z}");
std::println!("{z_reference}");
z_reference.into_data().assert_approx_eq(&z.into_data(), 3);
}
}

View File

@ -4,7 +4,6 @@ mod cat;
mod comparison;
mod index;
mod mask;
mod matmul;
mod reduction;
mod source;
mod unary;
@ -12,11 +11,13 @@ mod unary_scalar;
pub use base::*;
pub use binary_elemwise::*;
pub use matmul::*;
pub use source::*;
pub use unary::*;
pub use unary_scalar::*;
/// Matmul kernels
pub mod matmul;
pub(crate) use cat::*;
pub(crate) use comparison::*;
pub(crate) use index::*;

View File

@ -38,7 +38,9 @@ mod tests {
pub type TestBackend = WgpuBackend<GraphicsApi, f32, i32>;
pub type ReferenceBackend = burn_ndarray::NdArrayBackend<f32>;
pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
pub type ReferenceTensor<const D: usize> = burn_tensor::Tensor<ReferenceBackend, D>;
pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;
burn_tensor::testgen_add!();

View File

@ -1,7 +1,6 @@
use super::{numeric, BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
use crate::kernel::{
self, matmul_tiling_2d_default, unary_default, unary_inplace_default, unary_scalar_default,
unary_scalar_inplace_default,
self, unary_default, unary_inplace_default, unary_scalar_default, unary_scalar_inplace_default,
};
use crate::unary_scalar_inplace;
use crate::{
@ -140,7 +139,7 @@ where
let lhs = kernel::into_continuous(lhs);
let rhs = kernel::into_continuous(rhs);
matmul_tiling_2d_default::<FloatElem<Self>, D>(lhs, rhs)
kernel::matmul::tile_vectorized::matmul_tiling_2d_default(lhs, rhs)
}
fn swap_dims<const D: usize>(

View File

@ -0,0 +1,139 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
const B_M_X_B_K = {{bm_x_bk}}u;
const B_K_X_B_N = {{bk_x_bn}}u;
const T_M = {{t_m}}u;
const T_N = {{t_n}}u;
const T_M_X_T_N = {{tm_x_tn}}u;
var<workgroup> shared_lhs: array<{{ elem }}, B_M_X_B_K>;
var<workgroup> shared_rhs: array<{{ elem }}, B_K_X_B_N>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let skip_row = workgroup_id.x * B_M;
let skip_col = workgroup_id.y * B_N;
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
let thread_row = (local_idx / n_thread_per_row) * T_M;
let thread_col = (local_idx % n_thread_per_row) * T_N;
let row = skip_row + thread_row;
let col = skip_col + thread_col;
let batch = global_id.z;
// Basic information
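// Assumed layout of `info`, inferred from the indices used below:
// info[0] = dim, then dim strides for lhs, rhs and output,
// followed by dim shape entries for lhs, rhs and output,
// so K is the second-to-last dim of rhs and n_rows / n_cols
// are the last two dims of the output.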
let dim = info[0];
let n_rows = info[6u * dim - 1u];
let n_cols = info[6u * dim];
let K = info[5u * dim - 1u];
// Calculate the corresponding offsets with support for broadcasting.
let offset_output = batch * n_rows * n_cols;
var offset_lhs: u32 = skip_row * K;
var offset_rhs: u32 = skip_col;
let batch_dims = dim - 2u;
for (var b: u32 = 1u; b <= batch_dims; b++) {
let stride_lhs = info[b];
let stride_rhs = info[b + dim];
let stride_output = info[b + 2u * dim];
let shape_lhs = info[b + 3u * dim];
let shape_rhs = info[b + 4u * dim];
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
var results: array<{{ elem }}, T_M_X_T_N>;
var register_M: array<{{ elem }}, T_M>;
var register_N: array<{{ elem }}, T_N>;
let thread_offset = local_idx * T_M_X_T_N;
for (var k = 0u; k < K; k += B_K) {
for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) {
let lhs_sm_position = thread_offset + load_index;
let block_row = lhs_sm_position / B_K;
let block_col = lhs_sm_position % B_K;
let lhs_position = offset_lhs + k + block_row * K + block_col;
if block_row < B_M {
shared_lhs[lhs_sm_position] = lhs[lhs_position];
}
}
for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) {
let rhs_sm_position = thread_offset + load_index;
let block_row = rhs_sm_position / B_N;
let block_col = rhs_sm_position % B_N;
let rhs_position = offset_rhs + (k + block_row) * n_cols + block_col;
if block_row < B_K {
shared_rhs[rhs_sm_position] = rhs[rhs_position];
}
}
workgroupBarrier();
// Compute intermediate results
// Results are accumulated in the results array and updated at each block
// Outer loop indicates which subcolumns/subrows to read from shared memories
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
// Load a subcolumn of values from lhs
for (var tile_index = 0u; tile_index < T_M; tile_index++) {
let lhs_sm_position = (thread_row + tile_index) * B_K + dot_index;
register_M[tile_index] = shared_lhs[lhs_sm_position];
}
// Load a subrow of values from rhs
for (var tile_index = 0u; tile_index < T_N; tile_index++) {
let rhs_sm_position = thread_col + tile_index + dot_index * B_N;
register_N[tile_index] = shared_rhs[rhs_sm_position];
}
// Multiply subcolumn and subrow and store results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
}
}
}
workgroupBarrier();
}
// Write output matrix
// Each thread is responsible for writing T_M x T_N results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
let current_row = row + res_idx_M;
let current_col = col + res_idx_N;
// No bounds check needed: inputs are padded so the (rounded) output is a multiple of the block sizes
let result_position = res_idx_M * T_N + res_idx_N;
let output_position = offset_output + current_row * n_cols + current_col;
output[output_position] = results[result_position];
}
}
}

View File

@ -0,0 +1,138 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
const B_M_X_B_K = {{bm_x_bk}}u;
const B_K_X_B_N = {{bk_x_bn}}u;
const T_M = {{t_m}}u;
const T_N = {{t_n}}u;
const T_M_X_T_N = {{tm_x_tn}}u;
var<workgroup> shared_lhs: array<{{ elem }}, B_M_X_B_K>;
var<workgroup> shared_rhs: array<{{ elem }}, B_K_X_B_N>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let skip_row = workgroup_id.x * B_M;
let skip_col = workgroup_id.y * B_N;
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
let thread_row = (local_idx / n_thread_per_row) * T_M;
let thread_col = (local_idx % n_thread_per_row) * T_N;
let row = skip_row + thread_row;
let col = skip_col + thread_col;
let batch = global_id.z;
// Basic information
let dim = info[0];
let n_rows = info[6u * dim - 1u];
let n_cols = info[6u * dim];
let K = info[5u * dim - 1u];
// Calculate the corresponding offsets with support for broadcasting.
let offset_output = batch * n_rows * n_cols;
var offset_lhs: u32 = skip_row * K;
var offset_rhs: u32 = skip_col;
let batch_dims = dim - 2u;
for (var b: u32 = 1u; b <= batch_dims; b++) {
let stride_lhs = info[b];
let stride_rhs = info[b + dim];
let stride_output = info[b + 2u * dim];
let shape_lhs = info[b + 3u * dim];
let shape_rhs = info[b + 4u * dim];
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
var results: array<{{ elem }}, T_M_X_T_N>;
var register_M: array<{{ elem }}, T_M>;
var register_N: array<{{ elem }}, T_N>;
let thread_offset = local_idx * T_M_X_T_N;
for (var k = 0u; k < K; k += B_K) {
// tile: let lhs_sm_position = current_row * B_K + current_col;
// tile_vec: let lhs_sm_position = current_row + current_col * B_M;
for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) {
let lhs_sm_position = thread_offset + load_index;
let block_row = lhs_sm_position % B_M;
let block_col = lhs_sm_position / B_M;
let lhs_position = offset_lhs + k + block_row * K + block_col;
if block_col < B_K {
shared_lhs[lhs_sm_position] = lhs[lhs_position];
}
}
for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) {
let rhs_sm_position = thread_offset + load_index;
let block_row = rhs_sm_position / B_N;
let block_col = rhs_sm_position % B_N;
let rhs_position = offset_rhs + (k + block_row) * n_cols + block_col;
if block_row < B_K {
shared_rhs[rhs_sm_position] = rhs[rhs_position];
}
}
workgroupBarrier();
// Compute intermediate results
// Results are accumulated in the results array and updated at each block
// Outer loop indicates which subcolumns/subrows to read from shared memories
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
// Load a subcolumn of values from lhs
for (var tile_index = 0u; tile_index < T_M; tile_index++) {
let lhs_sm_position = thread_row + tile_index + dot_index * B_M;
register_M[tile_index] = shared_lhs[lhs_sm_position];
}
// Load a subrow of values from rhs
for (var tile_index = 0u; tile_index < T_N; tile_index++) {
let rhs_sm_position = thread_col + tile_index + dot_index * B_N;
register_N[tile_index] = shared_rhs[rhs_sm_position];
}
// Multiply subcolumn and subrow and store results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
}
}
}
workgroupBarrier();
}
// Write output matrix
// Each thread is responsible for writing T_M x T_N results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
let result_position = res_idx_M * T_N + res_idx_N;
let output_position = offset_output + (row + res_idx_M) * n_cols + col + res_idx_N;
output[output_position] = results[result_position];
}
}
}

View File

@ -68,35 +68,31 @@ fn main(
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
// In case B_M % T_M != 0 or B_N % T_N != 0
// A thread must not read out of its block
let actual_T_M = min(B_M - thread_row, T_M);
let actual_T_N = min(B_N - thread_col, T_N);
var results: array<{{ elem }}, T_M_X_T_N>;
var register_M: array<{{ elem }}, T_M>;
var register_N: array<{{ elem }}, T_N>;
let thread_offset = local_idx * T_M_X_T_N;
for (var k = 0u; k < K; k += B_K) {
// sm_limit ensures that although there are up to B_M x B_N writes to memory,
// shared memories remain B_M x B_K (lhs) or B_K x B_N (rhs)
// also ensures we do not read out of matrices if M % B_M != 0 or N % B_N != 0
let sm_limit = min(B_K, K - k);
// Load data into shared memories
// Each thread is responsible of loading T_M x T_N values from both lhs and rhs
for (var i = 0u; i < actual_T_M; i++) {
for (var j = 0u; j < actual_T_N; j++) {
for (var i = 0u; i < T_M; i++) {
for (var j = 0u; j < T_N; j++) {
let current_row = thread_row + i;
let current_col = thread_col + j;
if current_col < sm_limit {
if current_col < B_K {
let lhs_sm_position = current_row * B_K + current_col;
let lhs_position = offset_lhs + k + current_row * K + current_col;
shared_lhs[lhs_sm_position] = lhs[lhs_position];
}
if current_row < sm_limit {
if current_row < B_K {
let rhs_sm_position = current_row * B_N + current_col;
let rhs_position = offset_rhs + (k + current_row) * n_cols + current_col;
shared_rhs[rhs_sm_position] = rhs[rhs_position];
@ -104,26 +100,27 @@ fn main(
}
}
workgroupBarrier();
// Compute intermediate results
// Results are cumulated in results array and updated at each block
// Outer loop indicates which subcolumns/subrows to read from shared memories
for (var dot_index = 0u; dot_index < sm_limit; dot_index++) {
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
// Load a subcolumn of values from lhs
for (var tile_index = 0u; tile_index < actual_T_M; tile_index++) {
for (var tile_index = 0u; tile_index < T_M; tile_index++) {
let lhs_sm_position = (thread_row + tile_index) * B_K + dot_index;
register_M[tile_index] = shared_lhs[lhs_sm_position];
}
// Load a subrow of values from rhs
for (var tile_index = 0u; tile_index < actual_T_N; tile_index++) {
for (var tile_index = 0u; tile_index < T_N; tile_index++) {
let rhs_sm_position = thread_col + tile_index + dot_index * B_N;
register_N[tile_index] = shared_rhs[rhs_sm_position];
}
// Multiply subcolumn and subrow and store results
for (var res_idx_M = 0u; res_idx_M < actual_T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < actual_T_N; res_idx_N++) {
results[res_idx_M * actual_T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
}
}
}
@ -133,16 +130,11 @@ fn main(
// Write output matrix
// Each thread is responsible of writing T_M x T_N results
for (var res_idx_M = 0u; res_idx_M < actual_T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < actual_T_N; res_idx_N++) {
let current_row = row + res_idx_M;
let current_col = col + res_idx_N;
// Check that we are within the bounds of output matrix
if current_row < n_rows && current_col < n_cols {
let result_position = res_idx_M * actual_T_N + res_idx_N;
let output_position = offset_output + current_row * n_cols + current_col;
output[output_position] = results[result_position];
}
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
let result_position = res_idx_M * T_N + res_idx_N;
let output_position = offset_output + (row + res_idx_M) * n_cols + col + res_idx_N;
output[output_position] = results[result_position];
}
}
}

View File

@ -0,0 +1,140 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
const B_M_X_B_K = {{bm_x_bk}}u;
const B_K_X_B_N = {{bk_x_bn}}u;
const T_M = {{t_m}}u;
const T_N = {{t_n}}u;
const T_M_X_T_N = {{tm_x_tn}}u;
var<workgroup> shared_lhs: array<{{ elem }}, B_M_X_B_K>;
var<workgroup> shared_rhs: array<{{ elem }}, B_K_X_B_N>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let skip_row = workgroup_id.x * B_M;
let skip_col = workgroup_id.y * B_N;
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
let thread_row = (local_idx / n_thread_per_row) * T_M;
let thread_col = (local_idx % n_thread_per_row) * T_N;
let row = skip_row + thread_row;
let col = skip_col + thread_col;
let batch = global_id.z;
// Basic information
let dim = info[0];
let n_rows = info[6u * dim - 1u];
let n_cols = info[6u * dim];
let K = info[5u * dim - 1u];
// Calculate the corresponding offsets with support for broadcasting.
let offset_output = batch * n_rows * n_cols;
var offset_lhs: u32 = skip_row * K;
var offset_rhs: u32 = skip_col;
let batch_dims = dim - 2u;
for (var b: u32 = 1u; b <= batch_dims; b++) {
let stride_lhs = info[b];
let stride_rhs = info[b + dim];
let stride_output = info[b + 2u * dim];
let shape_lhs = info[b + 3u * dim];
let shape_rhs = info[b + 4u * dim];
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
var results: array<{{ elem }}, T_M_X_T_N>;
var register_M: array<{{ elem }}, T_M>;
var register_N: array<{{ elem }}, T_N>;
let thread_offset = local_idx * T_M_X_T_N;
for (var k = 0u; k < K; k += B_K) {
// The `current_col < B_K` / `current_row < B_K` checks below ensure that although there are
// up to B_M x B_N loads per block, shared memories remain B_M x B_K (lhs) and B_K x B_N (rhs).
// Load data into shared memories
// Each thread is responsible for loading T_M x T_N values from both lhs and rhs
for (var i = 0u; i < T_M; i++) {
for (var j = 0u; j < T_N; j++) {
let current_row = thread_row + i;
let current_col = thread_col + j;
if current_col < B_K {
let lhs_sm_position = current_row + current_col * B_M;
let lhs_position = offset_lhs + k + current_row * K + current_col;
shared_lhs[lhs_sm_position] = lhs[lhs_position];
}
if current_row < B_K {
let rhs_sm_position = current_row * B_N + current_col;
let rhs_position = offset_rhs + (k + current_row) * n_cols + current_col;
shared_rhs[rhs_sm_position] = rhs[rhs_position];
}
}
}
workgroupBarrier();
// Compute intermediate results
// Results are accumulated in the results array and updated at each block
// Outer loop indicates which subcolumns/subrows to read from shared memories
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
// Load a subcolumn of values from lhs
for (var tile_index = 0u; tile_index < T_M; tile_index++) {
let lhs_sm_position = thread_row + tile_index + dot_index * B_M;
register_M[tile_index] = shared_lhs[lhs_sm_position];
}
// Load a subrow of values from rhs
for (var tile_index = 0u; tile_index < T_N; tile_index++) {
let rhs_sm_position = thread_col + tile_index + dot_index * B_N;
register_N[tile_index] = shared_rhs[rhs_sm_position];
}
// Multiply subcolumn and subrow and store results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
}
}
}
workgroupBarrier();
}
// Write output matrix
// Each thread is responsible for writing T_M x T_N results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
let result_position = res_idx_M * T_N + res_idx_N;
let output_position = offset_output + (row + res_idx_M) * n_cols + col + res_idx_N;
output[output_position] = results[result_position];
}
}
}