mirror of https://github.com/tracel-ai/burn.git
Feat/matmul/faster (#479)
This commit is contained in:
parent 261aa952c0
commit 513b9281c2
@ -1,11 +1,13 @@
use std::marker::PhantomData;
use burn_tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_wgpu::{
    benchmark::Benchmark,
    kernel::{matmul_mem_coalescing_default, matmul_naive_default, matmul_tiling_2d_default},
    kernel::matmul::{
        continuous, continuous_vectorized, matmul_mem_coalescing_default, matmul_naive_default,
        tile, tile_vectorized,
    },
    run_benchmark, GraphicsApi, WgpuBackend, WgpuDevice,
};
use std::marker::PhantomData;

trait MatmulFunction<B: Backend, const D: usize> {
    fn run(lhs: Tensor<B, D>, rhs: Tensor<B, D>) -> Tensor<B, D>;
@ -37,6 +39,10 @@ where
        )
    }

    fn num_samples(&self) -> usize {
        10
    }

    fn execute(&self, (lhs, rhs): Self::Args) {
        for _ in 0..self.num_repeats {
            F::run(lhs.clone(), rhs.clone());
@ -51,71 +57,69 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
struct Tiling2DMatmul;
|
||||
macro_rules! benchmark {
|
||||
($name:ident, $func:expr) => {
|
||||
struct $name;
|
||||
|
||||
impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D>
|
||||
for Tiling2DMatmul
|
||||
{
|
||||
fn run(
|
||||
lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
|
||||
rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
|
||||
) -> Tensor<WgpuBackend<G, f32, i32>, D> {
|
||||
Tensor::from_primitive(matmul_tiling_2d_default(
|
||||
lhs.into_primitive(),
|
||||
rhs.into_primitive(),
|
||||
))
|
||||
}
|
||||
impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D> for $name {
|
||||
fn run(
|
||||
lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
|
||||
rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
|
||||
) -> Tensor<WgpuBackend<G, f32, i32>, D> {
|
||||
Tensor::from_primitive($func(lhs.into_primitive(), rhs.into_primitive()))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
struct NaiveMatmul;
|
||||
|
||||
impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D> for NaiveMatmul {
|
||||
fn run(
|
||||
lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
|
||||
rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
|
||||
) -> Tensor<WgpuBackend<G, f32, i32>, D> {
|
||||
Tensor::from_primitive(matmul_naive_default(
|
||||
lhs.into_primitive(),
|
||||
rhs.into_primitive(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
struct MemCoalescingMatmul;
|
||||
|
||||
impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D>
|
||||
for MemCoalescingMatmul
|
||||
{
|
||||
fn run(
|
||||
lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
|
||||
rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
|
||||
) -> Tensor<WgpuBackend<G, f32, i32>, D> {
|
||||
Tensor::from_primitive(matmul_mem_coalescing_default(
|
||||
lhs.into_primitive(),
|
||||
rhs.into_primitive(),
|
||||
))
|
||||
}
|
||||
}
|
||||
benchmark!(NaiveMatmul, matmul_naive_default);
|
||||
benchmark!(MemCoalescingMatmul, matmul_mem_coalescing_default);
|
||||
benchmark!(
|
||||
Tiling2DMatmulContinuous,
|
||||
continuous::matmul_tiling_2d_default
|
||||
);
|
||||
benchmark!(Tiling2DMatmulTile, tile::matmul_tiling_2d_default);
|
||||
benchmark!(
|
||||
Tiling2DMatmulTileVectorized,
|
||||
tile_vectorized::matmul_tiling_2d_default
|
||||
);
|
||||
benchmark!(
|
||||
Tiling2DMatmulContinuousVectorized,
|
||||
continuous_vectorized::matmul_tiling_2d_default
|
||||
);
|
||||
|
||||
fn main() {
|
||||
let batch_size = 32;
|
||||
let matrix_size = 128;
|
||||
run_benchmark!(MatmulBenchmark::<NaiveMatmul, 3> {
|
||||
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
num_repeats: 10,
|
||||
matmul: PhantomData::default()
|
||||
});
|
||||
let num_repeats = 3;
|
||||
let batch_size = 3;
|
||||
let matrix_size = 1000;
|
||||
run_benchmark!(MatmulBenchmark::<MemCoalescingMatmul, 3> {
|
||||
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
num_repeats: 10,
|
||||
num_repeats,
|
||||
matmul: PhantomData::default()
|
||||
});
|
||||
run_benchmark!(MatmulBenchmark::<Tiling2DMatmul, 3> {
|
||||
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulContinuous, 3> {
|
||||
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
num_repeats: 10,
|
||||
num_repeats,
|
||||
matmul: PhantomData::default()
|
||||
});
|
||||
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulContinuousVectorized, 3> {
|
||||
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
num_repeats,
|
||||
matmul: PhantomData::default()
|
||||
});
|
||||
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulTile, 3> {
|
||||
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
num_repeats,
|
||||
matmul: PhantomData::default()
|
||||
});
|
||||
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulTileVectorized, 3> {
|
||||
shape_lhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
shape_rhs: [batch_size, matrix_size, matrix_size].into(),
|
||||
num_repeats,
|
||||
matmul: PhantomData::default()
|
||||
});
|
||||
}
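Note on the new `benchmark!` macro defined above: it replaces the hand-written per-kernel structs, so wiring another kernel into this binary is one macro invocation plus one `run_benchmark!` entry. A minimal sketch, assuming a hypothetical kernel function `my_matmul_default` with the same `(WgpuTensor, WgpuTensor) -> WgpuTensor` signature as the kernels benchmarked here; the names `MyMatmul`, `my_matmul_default` and `bench_my_matmul` are illustrative only:

// Hypothetical example: `my_matmul_default` stands in for any kernel function
// with the `fn(WgpuTensor<E, D>, WgpuTensor<E, D>) -> WgpuTensor<E, D>` shape used above.
benchmark!(MyMatmul, my_matmul_default);

fn bench_my_matmul() {
    let num_repeats = 3;
    let batch_size = 3;
    let matrix_size = 1000;
    run_benchmark!(MatmulBenchmark::<MyMatmul, 3> {
        shape_lhs: [batch_size, matrix_size, matrix_size].into(),
        shape_rhs: [batch_size, matrix_size, matrix_size].into(),
        num_repeats,
        matmul: PhantomData::default()
    });
}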
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use super::utils::shape_out;
|
||||
use crate::{
|
||||
context::WorkGroup,
|
||||
element::WgpuElement,
|
||||
|
@ -5,11 +6,10 @@ use crate::{
|
|||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
kernel_wgsl!(
|
||||
MatmulMemCoalescingRaw,
|
||||
"../../template/matmul_mem_coalescing.wgsl"
|
||||
"../../template/matmul/mem_coalescing.wgsl"
|
||||
);
|
||||
|
||||
struct MatmulMemCoalescing<const WORKGROUP_SIZE_X: usize, const WORKGROUP_SIZE_Y: usize>;
|
||||
|
@ -44,21 +44,9 @@ pub fn matmul_mem_coalescing<
|
|||
) -> WgpuTensor<E, D> {
|
||||
lhs.assert_is_on_same_device(&rhs);
|
||||
|
||||
let mut shape_out = [0; D];
|
||||
lhs.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(rhs.shape.dims.iter())
|
||||
.enumerate()
|
||||
.for_each(|(index, (dim_lhs, dim_rhs))| {
|
||||
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
|
||||
});
|
||||
|
||||
let shape_out = shape_out(&lhs, &rhs);
|
||||
let num_rows = lhs.shape.dims[D - 2];
|
||||
let num_cols = rhs.shape.dims[D - 1];
|
||||
shape_out[D - 2] = num_rows;
|
||||
shape_out[D - 1] = num_cols;
|
||||
let shape_out = Shape::new(shape_out);
|
||||
|
||||
let buffer = lhs
|
||||
.context
|
||||
|
@ -103,10 +91,7 @@ pub fn matmul_mem_coalescing<
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tests::TestTensor;
|
||||
|
||||
pub type ReferenceTensor<const D: usize> =
|
||||
burn_tensor::Tensor<burn_ndarray::NdArrayBackend<f32>, D>;
|
||||
use crate::kernel::matmul::utils::tests::same_as_reference;
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_mem_coalescing_straightforward() {
|
||||
|
@ -167,25 +152,4 @@ mod tests {
|
|||
let shape_rhs = [batch_1, batch_2, k, n];
|
||||
same_as_reference(func, shape_lhs, shape_rhs);
|
||||
}
|
||||
|
||||
fn same_as_reference<F, const D: usize, S>(func: F, shape_lhs: S, shape_rhs: S)
|
||||
where
|
||||
F: Fn(WgpuTensor<f32, D>, WgpuTensor<f32, D>) -> WgpuTensor<f32, D>,
|
||||
S: Into<Shape<D>>,
|
||||
{
|
||||
let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
|
||||
let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
|
||||
|
||||
let x_wgpu = TestTensor::from_data(x.to_data());
|
||||
let y_wgpu = TestTensor::from_data(y.to_data());
|
||||
|
||||
let z_reference = x.matmul(y);
|
||||
|
||||
let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive());
|
||||
let z = TestTensor::from_primitive(z);
|
||||
|
||||
println!("{z}");
|
||||
|
||||
z_reference.into_data().assert_approx_eq(&z.into_data(), 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
pub(crate) mod utils;

mod mem_coalescing;
mod naive;
mod tiling2d;
@ -1,3 +1,4 @@
|
|||
use super::utils::shape_out;
|
||||
use crate::{
|
||||
context::WorkGroup,
|
||||
element::WgpuElement,
|
||||
|
@ -5,9 +6,8 @@ use crate::{
|
|||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul_naive.wgsl");
|
||||
kernel_wgsl!(MatmulNaiveRaw, "../../template/matmul/naive.wgsl");
|
||||
|
||||
struct MatmulNaive<const WORKGROUP_SIZE_X: usize, const WORKGROUP_SIZE_Y: usize>;
|
||||
|
||||
|
@ -41,21 +41,10 @@ pub fn matmul_naive<
|
|||
) -> WgpuTensor<E, D> {
|
||||
lhs.assert_is_on_same_device(&rhs);
|
||||
|
||||
let mut shape_out = [0; D];
|
||||
lhs.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(rhs.shape.dims.iter())
|
||||
.enumerate()
|
||||
.for_each(|(index, (dim_lhs, dim_rhs))| {
|
||||
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
|
||||
});
|
||||
let shape_out = shape_out(&lhs, &rhs);
|
||||
|
||||
let num_rows = lhs.shape.dims[D - 2];
|
||||
let num_cols = rhs.shape.dims[D - 1];
|
||||
shape_out[D - 2] = num_rows;
|
||||
shape_out[D - 1] = num_cols;
|
||||
let shape_out = Shape::new(shape_out);
|
||||
|
||||
let buffer = lhs
|
||||
.context
|
||||
|
@ -100,10 +89,7 @@ pub fn matmul_naive<
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tests::TestTensor;
|
||||
|
||||
pub type ReferenceTensor<const D: usize> =
|
||||
burn_tensor::Tensor<burn_ndarray::NdArrayBackend<f32>, D>;
|
||||
use crate::kernel::matmul::utils::tests::same_as_reference;
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_naive_straightforward() {
|
||||
|
@ -162,25 +148,4 @@ mod tests {
|
|||
let shape_rhs = [batch_1, batch_2, k, n];
|
||||
same_as_reference(func, shape_lhs, shape_rhs);
|
||||
}
|
||||
|
||||
fn same_as_reference<F, const D: usize, S>(func: F, shape_lhs: S, shape_rhs: S)
|
||||
where
|
||||
F: Fn(WgpuTensor<f32, D>, WgpuTensor<f32, D>) -> WgpuTensor<f32, D>,
|
||||
S: Into<Shape<D>>,
|
||||
{
|
||||
let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
|
||||
let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
|
||||
|
||||
let x_wgpu = TestTensor::from_data(x.to_data());
|
||||
let y_wgpu = TestTensor::from_data(y.to_data());
|
||||
|
||||
let z_reference = x.matmul(y);
|
||||
|
||||
let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive());
|
||||
let z = TestTensor::from_primitive(z);
|
||||
|
||||
println!("{z}");
|
||||
|
||||
z_reference.into_data().assert_approx_eq(&z.into_data(), 3);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,330 +0,0 @@
|
|||
use std::cmp::{max, min};
|
||||
|
||||
use crate::{
|
||||
context::WorkGroup,
|
||||
element::WgpuElement,
|
||||
kernel::{build_info, KernelSettings, SourceTemplate, StaticKernel},
|
||||
kernel_wgsl,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
const MAX_SHARED_MEMORY_SIZE: usize = 8192;
|
||||
|
||||
kernel_wgsl!(
|
||||
MatmulTiling2DRaw,
|
||||
"../../template/matmul_blocktiling_2d.wgsl"
|
||||
);
|
||||
|
||||
struct MatmulTiling2D<
|
||||
const B_M: usize,
|
||||
const B_N: usize,
|
||||
const B_K: usize,
|
||||
const T_M: usize,
|
||||
const T_N: usize,
|
||||
const WORKGROUP_SIZE_X: usize,
|
||||
const WORKGROUP_SIZE_Y: usize,
|
||||
>;
|
||||
|
||||
impl<
|
||||
const B_M: usize,
|
||||
const B_N: usize,
|
||||
const B_K: usize,
|
||||
const T_M: usize,
|
||||
const T_N: usize,
|
||||
const WORKGROUP_SIZE_X: usize,
|
||||
const WORKGROUP_SIZE_Y: usize,
|
||||
> StaticKernel for MatmulTiling2D<B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>
|
||||
{
|
||||
fn source_template() -> SourceTemplate {
|
||||
MatmulTiling2DRaw::source_template()
|
||||
.register("b_m", B_M.to_string())
|
||||
.register("b_n", B_N.to_string())
|
||||
.register("b_k", B_K.to_string())
|
||||
.register("bm_x_bk", (B_M * B_K).to_string())
|
||||
.register("bk_x_bn", (B_K * B_N).to_string())
|
||||
.register("t_m", T_M.to_string())
|
||||
.register("t_n", T_N.to_string())
|
||||
.register("tm_x_tn", (T_M * T_N).to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Matrix multiplication using tiling 2D algorithm with default parameters
|
||||
pub fn matmul_tiling_2d_default<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
// Suppose a matmul of m1 of size [M, K] with m2 of size [K, N]
|
||||
// Block size along dim M
|
||||
const B_M: usize = 128;
|
||||
// Block size along dim N
|
||||
const B_N: usize = 128;
|
||||
// Block size along dim K
|
||||
const B_K: usize = 8;
|
||||
// Tiling size along dim M
|
||||
const T_M: usize = 8;
|
||||
// Tiling size along dim N
|
||||
const T_N: usize = 8;
|
||||
// WORKGROUP_SIZE_X = ceil(B_M / T_M)
|
||||
const WORKGROUP_SIZE_X: usize = 16;
|
||||
// WORKGROUP_SIZE_Y = ceil(B_N / T_N)
|
||||
const WORKGROUP_SIZE_Y: usize = 16;
|
||||
|
||||
matmul_tiling_2d::<E, D, B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>(lhs, rhs)
|
||||
}
|
||||
|
||||
/// Matrix multiplication using tiling 2D algorithm with custom parameters
|
||||
pub fn matmul_tiling_2d<
|
||||
E: WgpuElement,
|
||||
const D: usize,
|
||||
const B_M: usize,
|
||||
const B_N: usize,
|
||||
const B_K: usize,
|
||||
const T_M: usize,
|
||||
const T_N: usize,
|
||||
const WORKGROUP_SIZE_X: usize,
|
||||
const WORKGROUP_SIZE_Y: usize,
|
||||
>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
assert!(B_K <= min(B_M, B_N), "B_K must be smaller than or equal to both B_M and B_N, otherwise there won't be enough threads to fill shared memory.");
|
||||
assert!(B_K * max(B_M, B_N) <= MAX_SHARED_MEMORY_SIZE, "B_K x B_M and B_K x B_N must be smaller than or equal to 8192, otherwise the shared memory limit will be exceeded.");
|
||||
assert!(
|
||||
WORKGROUP_SIZE_X == f32::ceil(B_M as f32 / T_M as f32) as usize,
|
||||
"Workgroup size x must equal ceil(B_M / T_M)"
|
||||
);
|
||||
assert!(
|
||||
WORKGROUP_SIZE_Y == f32::ceil(B_N as f32 / T_N as f32) as usize,
|
||||
"Workgroup size y must equal ceil(B_N / T_N)"
|
||||
);
|
||||
lhs.assert_is_on_same_device(&rhs);
|
||||
|
||||
let mut shape_out = [0; D];
|
||||
lhs.shape
|
||||
.dims
|
||||
.iter()
|
||||
.zip(rhs.shape.dims.iter())
|
||||
.enumerate()
|
||||
.for_each(|(index, (dim_lhs, dim_rhs))| {
|
||||
shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
|
||||
});
|
||||
|
||||
let num_rows = lhs.shape.dims[D - 2];
|
||||
let num_cols = rhs.shape.dims[D - 1];
|
||||
shape_out[D - 2] = num_rows;
|
||||
shape_out[D - 1] = num_cols;
|
||||
let shape_out = Shape::new(shape_out);
|
||||
|
||||
let buffer = lhs
|
||||
.context
|
||||
.create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(lhs.context.clone(), shape_out, buffer);
|
||||
|
||||
// set number of workgroups
|
||||
let blocks_needed_in_x = f32::ceil(num_rows as f32 / B_M as f32) as u32;
|
||||
let blocks_needed_in_y = f32::ceil(num_cols as f32 / B_N as f32) as u32;
|
||||
|
||||
let kernel = lhs.context.compile_static::<KernelSettings<
|
||||
MatmulTiling2D<B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>,
|
||||
E,
|
||||
i32,
|
||||
WORKGROUP_SIZE_X,
|
||||
WORKGROUP_SIZE_Y,
|
||||
1,
|
||||
>>();
|
||||
|
||||
let info = build_info(&[&lhs, &rhs, &output]);
|
||||
|
||||
let info_buffers = lhs
|
||||
.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info));
|
||||
|
||||
let mut num_iter = 1;
|
||||
for i in 0..D - 2 {
|
||||
num_iter *= output.shape.dims[i];
|
||||
}
|
||||
|
||||
let workgroup = WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32);
|
||||
|
||||
lhs.context.execute(
|
||||
workgroup,
|
||||
kernel,
|
||||
&[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tests::TestTensor;
|
||||
|
||||
pub type ReferenceTensor<const D: usize> =
|
||||
burn_tensor::Tensor<burn_ndarray::NdArrayBackend<f32>, D>;
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_shapes_smaller_than_blocks() {
|
||||
test_with_params::<128, 128, 16, 8, 8, 16, 16>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_m_not_equals_n() {
|
||||
test_with_params::<16, 16, 8, 8, 8, 2, 2>(8, 8, 3, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_k_smaller_than_m_n() {
|
||||
test_with_params::<16, 16, 8, 8, 8, 2, 2>(8, 3, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_k_larger_than_m_n() {
|
||||
test_with_params::<16, 16, 8, 8, 8, 2, 2>(8, 48, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_t_divides_b_unevenly() {
|
||||
test_with_params::<128, 128, 8, 7, 11, 19, 12>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_small_parameters() {
|
||||
test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_bm_not_equals_bn() {
|
||||
test_with_params::<32, 128, 8, 8, 8, 4, 16>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_multibatch_1_dim() {
|
||||
test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 8, 8, 3, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_multibatch_2_dims() {
|
||||
test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 8, 8, 3, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_memory_busted_should_panic() {
|
||||
test_with_params::<128, 128, 128, 8, 8, 16, 16>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_bk_larger_than_bm_should_panic() {
|
||||
test_with_params::<64, 64, 128, 8, 8, 8, 8>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_workgroup_x_wrong_should_panic() {
|
||||
test_with_params::<128, 128, 16, 8, 8, 16, 8>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_workgroup_y_wrong_should_panic() {
|
||||
test_with_params::<128, 128, 16, 8, 8, 8, 7>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_multiple_blocks() {
|
||||
test_with_params::<16, 16, 8, 8, 8, 2, 2>(32, 32, 32, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_k_bigger_than_bk() {
|
||||
test_with_params::<128, 128, 8, 8, 8, 16, 16>(8, 10, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_blocks_divide_shapes_unevenly() {
|
||||
test_with_params::<16, 16, 8, 8, 8, 2, 2>(31, 23, 17, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_large_parameters() {
|
||||
test_with_params::<256, 256, 16, 16, 16, 16, 16>(40, 40, 40, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_shapes_slightly_larger_than_blocks() {
|
||||
test_with_params::<32, 32, 8, 8, 8, 4, 4>(40, 40, 30, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_shapes_way_larger_than_blocks() {
|
||||
test_with_params::<16, 16, 8, 8, 8, 2, 2>(50, 50, 50, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_tm_larger_than_bm() {
|
||||
test_with_params::<2, 2, 2, 3, 2, 1, 1>(5, 5, 5, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_tn_larger_than_bn() {
|
||||
test_with_params::<2, 2, 2, 2, 3, 1, 1>(5, 5, 5, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_uneven_parameters() {
|
||||
test_with_params::<17, 15, 11, 13, 7, 2, 3>(24, 24, 24, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_uneven_parameters_2() {
|
||||
test_with_params::<11, 14, 10, 7, 17, 2, 1>(10, 24, 17, 1, 1);
|
||||
}
|
||||
|
||||
fn test_with_params<
|
||||
const B_M: usize,
|
||||
const B_N: usize,
|
||||
const B_K: usize,
|
||||
const T_M: usize,
|
||||
const T_N: usize,
|
||||
const WORKGROUP_SIZE_X: usize,
|
||||
const WORKGROUP_SIZE_Y: usize,
|
||||
>(
|
||||
m: usize,
|
||||
k: usize,
|
||||
n: usize,
|
||||
batch_1: usize,
|
||||
batch_2: usize,
|
||||
) {
|
||||
let func = |lhs, rhs| {
|
||||
matmul_tiling_2d::<f32, 4, B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>(
|
||||
lhs, rhs,
|
||||
)
|
||||
};
|
||||
let shape_lhs = [batch_1, batch_2, m, k];
|
||||
let shape_rhs = [batch_1, batch_2, k, n];
|
||||
same_as_reference(func, shape_lhs, shape_rhs);
|
||||
}
|
||||
|
||||
fn same_as_reference<F, const D: usize, S>(func: F, shape_lhs: S, shape_rhs: S)
|
||||
where
|
||||
F: Fn(WgpuTensor<f32, D>, WgpuTensor<f32, D>) -> WgpuTensor<f32, D>,
|
||||
S: Into<Shape<D>>,
|
||||
{
|
||||
let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
|
||||
let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
|
||||
|
||||
let x_wgpu = TestTensor::from_data(x.to_data());
|
||||
let y_wgpu = TestTensor::from_data(y.to_data());
|
||||
|
||||
let z_reference = x.matmul(y);
|
||||
|
||||
let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive());
|
||||
let z = TestTensor::from_primitive(z);
|
||||
|
||||
println!("{z}");
|
||||
z_reference.into_data().assert_approx_eq(&z.into_data(), 3);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,415 @@
|
|||
use crate::{
|
||||
context::{Context, WorkGroup},
|
||||
element::WgpuElement,
|
||||
kernel::{build_info, matmul::utils::shape_out, SourceTemplate},
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
use std::{
|
||||
cmp::{max, min},
|
||||
sync::Arc,
|
||||
};
|
||||
use wgpu::ComputePipeline;
|
||||
|
||||
use super::padding::{crop, pad_round};
|
||||
|
||||
const MAX_SHARED_MEMORY_SIZE: usize = 8192;
|
||||
|
||||
pub(super) fn empty_from_context<E: WgpuElement, const D: usize>(
|
||||
context: Arc<Context>,
|
||||
shape: &Shape<D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
|
||||
|
||||
WgpuTensor::new(context, shape.clone(), buffer)
|
||||
}
|
||||
|
||||
/// Create a tiling 2D matmul kernel (struct, launch functions and tests) from a WGSL template.
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! matmul_tile_2d {
|
||||
(
|
||||
$struct:ident,
|
||||
$file:expr
|
||||
) => {
|
||||
matmul_tile_2d!(
|
||||
$struct,
|
||||
$file,
|
||||
B_M 64,
|
||||
B_N 64,
|
||||
B_K 32,
|
||||
T_M 4,
|
||||
T_N 4
|
||||
);
|
||||
};
|
||||
|
||||
(
|
||||
$struct:ident,
|
||||
$file:expr,
|
||||
B_M $bm:expr,
|
||||
B_N $bn:expr,
|
||||
B_K $bk:expr,
|
||||
T_M $tm:expr,
|
||||
T_N $tn:expr
|
||||
) => {
|
||||
struct $struct<
|
||||
const B_M: usize,
|
||||
const B_N: usize,
|
||||
const B_K: usize,
|
||||
const T_M: usize,
|
||||
const T_N: usize,
|
||||
const WORKGROUP_SIZE_X: usize,
|
||||
const WORKGROUP_SIZE_Y: usize,
|
||||
>;
|
||||
|
||||
impl<
|
||||
const B_M: usize,
|
||||
const B_N: usize,
|
||||
const B_K: usize,
|
||||
const T_M: usize,
|
||||
const T_N: usize,
|
||||
const WORKGROUP_SIZE_X: usize,
|
||||
const WORKGROUP_SIZE_Y: usize,
|
||||
> StaticKernel
|
||||
for $struct<B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>
|
||||
{
|
||||
fn source_template() -> SourceTemplate {
|
||||
kernel_wgsl!(Raw, $file);
|
||||
|
||||
register_template::<B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>(
|
||||
Raw::source_template(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Matrix multiplication using tiling 2D algorithm with default parameters
|
||||
pub fn matmul_tiling_2d_default<E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
// Suppose a matmul of m1 of size [M, K] with m2 of size [K, N]
|
||||
// Block size along dim M
|
||||
const B_M: usize = $bm;
|
||||
// Block size along dim N
|
||||
const B_N: usize = $bn;
|
||||
// Block size along dim K
|
||||
const B_K: usize = $bk;
|
||||
// Tiling size along dim M
|
||||
const T_M: usize = $tm;
|
||||
// Tiling size along dim N
|
||||
const T_N: usize = $tn;
|
||||
// WORKGROUP_SIZE_X = B_M / T_M (T_M must divide B_M exactly)
|
||||
const WORKGROUP_SIZE_X: usize = B_M / T_M;
|
||||
// WORKGROUP_SIZE_Y = B_N / T_N (T_N must divide B_N exactly)
|
||||
const WORKGROUP_SIZE_Y: usize = B_N / T_N;
|
||||
|
||||
matmul_tiling_2d::<E, D, B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>(
|
||||
lhs, rhs,
|
||||
)
|
||||
}
|
||||
|
||||
/// Matrix multiplication using tiling 2D algorithm with custom parameters
|
||||
pub fn matmul_tiling_2d<
|
||||
E: WgpuElement,
|
||||
const D: usize,
|
||||
const B_M: usize,
|
||||
const B_N: usize,
|
||||
const B_K: usize,
|
||||
const T_M: usize,
|
||||
const T_N: usize,
|
||||
const WORKGROUP_SIZE_X: usize,
|
||||
const WORKGROUP_SIZE_Y: usize,
|
||||
>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let kernel = lhs.context.compile_static::<KernelSettings<
|
||||
$struct<B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>,
|
||||
E,
|
||||
i32,
|
||||
WORKGROUP_SIZE_X,
|
||||
WORKGROUP_SIZE_Y,
|
||||
1,
|
||||
>>();
|
||||
matmul_tiling_2d_launch::<
|
||||
E,
|
||||
D,
|
||||
B_M,
|
||||
B_N,
|
||||
B_K,
|
||||
T_M,
|
||||
T_N,
|
||||
WORKGROUP_SIZE_X,
|
||||
WORKGROUP_SIZE_Y,
|
||||
>(lhs, rhs, kernel)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use $crate::kernel::matmul::utils::tests::same_as_reference;
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_large_blocks() {
|
||||
test_with_params::<128, 128, 8, 4, 4, 32, 32>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_shapes_smaller_than_blocks() {
|
||||
test_with_params::<64, 64, 8, 4, 4, 16, 16>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_m_not_equals_n() {
|
||||
test_with_params::<16, 16, 8, 2, 2, 8, 8>(16, 8, 16, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_k_smaller_than_m_n() {
|
||||
test_with_params::<16, 16, 4, 2, 2, 8, 8>(16, 4, 16, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_k_larger_than_m_n() {
|
||||
test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 48, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_t_divides_b_unevenly_should_panic() {
|
||||
test_with_params::<128, 128, 8, 7, 11, 19, 12>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_bm_not_equals_bn() {
|
||||
test_with_params::<8, 16, 8, 2, 4, 4, 4>(8, 8, 16, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_multibatch_1_dim() {
|
||||
test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 8, 8, 3, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_multibatch_2_dims() {
|
||||
test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 8, 8, 3, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_memory_busted_should_panic() {
|
||||
test_with_params::<128, 128, 128, 8, 8, 16, 16>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_bk_larger_than_bm_should_panic() {
|
||||
test_with_params::<64, 64, 128, 8, 8, 8, 8>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_workgroup_x_wrong_should_panic() {
|
||||
test_with_params::<128, 128, 16, 8, 8, 16, 8>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_workgroup_y_wrong_should_panic() {
|
||||
test_with_params::<128, 128, 16, 8, 8, 8, 7>(8, 8, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_multiple_blocks() {
|
||||
test_with_params::<16, 16, 8, 2, 2, 8, 8>(32, 32, 32, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_k_bigger_than_bk() {
|
||||
test_with_params::<8, 8, 8, 2, 2, 4, 4>(8, 16, 8, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_blocks_divide_shapes_unevenly() {
|
||||
test_with_params::<16, 16, 8, 2, 2, 8, 8>(31, 23, 17, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_matmul_tiling_2d_shapes_way_larger_than_blocks() {
|
||||
test_with_params::<16, 16, 8, 2, 2, 8, 8>(48, 48, 48, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_tm_larger_than_bm_should_panic() {
|
||||
test_with_params::<2, 2, 2, 3, 2, 1, 1>(5, 5, 5, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_tn_larger_than_bn_should_panic() {
|
||||
test_with_params::<2, 2, 2, 2, 3, 1, 1>(5, 5, 5, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_uneven_parameters_should_panic() {
|
||||
test_with_params::<17, 15, 11, 13, 7, 2, 3>(24, 24, 24, 1, 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn test_matmul_tiling_2d_uneven_parameters_2_should_panic() {
|
||||
test_with_params::<11, 14, 10, 7, 17, 2, 1>(10, 24, 17, 1, 1);
|
||||
}
|
||||
|
||||
fn test_with_params<
|
||||
const B_M: usize,
|
||||
const B_N: usize,
|
||||
const B_K: usize,
|
||||
const T_M: usize,
|
||||
const T_N: usize,
|
||||
const WORKGROUP_SIZE_X: usize,
|
||||
const WORKGROUP_SIZE_Y: usize,
|
||||
>(
|
||||
m: usize,
|
||||
k: usize,
|
||||
n: usize,
|
||||
batch_1: usize,
|
||||
batch_2: usize,
|
||||
) {
|
||||
let func = |lhs, rhs| {
|
||||
matmul_tiling_2d::<f32, 4, B_M, B_N, B_K, T_M, T_N, WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y>(
|
||||
lhs, rhs,
|
||||
)
|
||||
};
|
||||
let shape_lhs = [batch_1, batch_2, m, k];
|
||||
let shape_rhs = [batch_1, batch_2, k, n];
|
||||
same_as_reference(func, shape_lhs, shape_rhs);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(super) fn register_template<
|
||||
const B_M: usize,
|
||||
const B_N: usize,
|
||||
const B_K: usize,
|
||||
const T_M: usize,
|
||||
const T_N: usize,
|
||||
const WORKGROUP_SIZE_X: usize,
|
||||
const WORKGROUP_SIZE_Y: usize,
|
||||
>(
|
||||
template: SourceTemplate,
|
||||
) -> SourceTemplate {
|
||||
template
|
||||
.register("b_m", B_M.to_string())
|
||||
.register("b_n", B_N.to_string())
|
||||
.register("b_k", B_K.to_string())
|
||||
.register("bm_x_bk", (B_M * B_K).to_string())
|
||||
.register("bk_x_bn", (B_K * B_N).to_string())
|
||||
.register("t_m", T_M.to_string())
|
||||
.register("t_n", T_N.to_string())
|
||||
.register("tm_x_tn", (T_M * T_N).to_string())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(super) fn matmul_parameter_assertions<E: WgpuElement, const D: usize>(
|
||||
b_m: usize,
|
||||
b_n: usize,
|
||||
b_k: usize,
|
||||
t_m: usize,
|
||||
t_n: usize,
|
||||
workgroup_size_x: usize,
|
||||
workgroup_size_y: usize,
|
||||
lhs: &WgpuTensor<E, D>,
|
||||
rhs: &WgpuTensor<E, D>,
|
||||
) {
|
||||
assert!(b_k <= min(b_m, b_n), "B_K must be smaller than or equal to both B_M and B_N, otherwise there won't be enough threads to fill shared memory.");
|
||||
assert!(b_k * max(b_m, b_n) <= MAX_SHARED_MEMORY_SIZE, "B_K x B_M and B_K x B_N must be smaller than or equal to 8192, otherwise the shared memory limit will be exceeded.");
|
||||
assert!(
|
||||
b_m % t_m == 0 && b_n % t_n == 0,
|
||||
"T_M must divide B_M in this version"
|
||||
);
|
||||
assert!(
|
||||
workgroup_size_x == b_m / t_m,
|
||||
"Workgroup size x must equal B_M / T_M"
|
||||
);
|
||||
assert!(
|
||||
workgroup_size_y == b_n / t_n,
|
||||
"Workgroup size y must equal B_N / T_N"
|
||||
);
|
||||
lhs.assert_is_on_same_device(rhs);
|
||||
}
|
||||
|
||||
pub(super) fn make_workgroup<const D: usize>(
|
||||
output_shape: Shape<D>,
|
||||
b_m: usize,
|
||||
b_n: usize,
|
||||
) -> WorkGroup {
|
||||
let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / b_m as f32) as u32;
|
||||
let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / b_n as f32) as u32;
|
||||
let mut num_blocks_z = 1;
|
||||
for i in 0..D - 2 {
|
||||
num_blocks_z *= output_shape.dims[i];
|
||||
}
|
||||
|
||||
WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32)
|
||||
}
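For intuition about the dispatch size: `make_workgroup` launches one workgroup per B_M x B_N block of the (padded) output and one z layer per batch element. A small sketch of that arithmetic in plain Rust, using illustrative shapes and the default block sizes; the function name and numbers are not part of the diff:

// Sketch only: mirrors the ceil-division done in `make_workgroup`.
fn workgroup_count_example() {
    let (b_m, b_n) = (64usize, 64usize);
    let output_dims = [3usize, 4, 1000, 1000]; // [batch_1, batch_2, rows, cols]

    let num_blocks_x = (output_dims[2] + b_m - 1) / b_m; // ceil(1000 / 64) = 16
    let num_blocks_y = (output_dims[3] + b_n - 1) / b_n; // ceil(1000 / 64) = 16
    let num_blocks_z: usize = output_dims[..2].iter().product(); // 3 * 4 = 12

    assert_eq!((num_blocks_x, num_blocks_y, num_blocks_z), (16, 16, 12));
    // WorkGroup::new(16, 16, 12) -> 16 * 16 * 12 = 3072 workgroups in total.
}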
|
||||
|
||||
pub(super) fn make_info_buffers<E: WgpuElement, const D: usize>(
|
||||
lhs: &WgpuTensor<E, D>,
|
||||
rhs: &WgpuTensor<E, D>,
|
||||
output: &WgpuTensor<E, D>,
|
||||
) -> Arc<wgpu::Buffer> {
|
||||
let info = build_info(&[lhs, rhs, output]);
|
||||
rhs.context
|
||||
.create_buffer_with_data(bytemuck::cast_slice(&info))
|
||||
}
|
||||
|
||||
pub(super) fn matmul_tiling_2d_launch<
|
||||
E: WgpuElement,
|
||||
const D: usize,
|
||||
const B_M: usize,
|
||||
const B_N: usize,
|
||||
const B_K: usize,
|
||||
const T_M: usize,
|
||||
const T_N: usize,
|
||||
const WORKGROUP_SIZE_X: usize,
|
||||
const WORKGROUP_SIZE_Y: usize,
|
||||
>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
kernel: Arc<ComputePipeline>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
matmul_parameter_assertions::<E, D>(
|
||||
B_M,
|
||||
B_N,
|
||||
B_K,
|
||||
T_M,
|
||||
T_N,
|
||||
WORKGROUP_SIZE_X,
|
||||
WORKGROUP_SIZE_Y,
|
||||
&lhs,
|
||||
&rhs,
|
||||
);
|
||||
|
||||
let final_output_shape = shape_out(&lhs, &rhs);
|
||||
let lhs = pad_round(lhs, B_M, B_K);
|
||||
let rhs = pad_round(rhs, B_K, B_N);
|
||||
let rounded_output_shape = shape_out(&lhs, &rhs);
|
||||
|
||||
let output = empty_from_context::<E, D>(rhs.context.clone(), &rounded_output_shape);
|
||||
|
||||
let workgroup = make_workgroup(rounded_output_shape, B_M, B_N);
|
||||
let info_buffers = make_info_buffers(&lhs, &rhs, &output);
|
||||
|
||||
lhs.context.execute(
|
||||
workgroup,
|
||||
kernel,
|
||||
&[&lhs.buffer, &rhs.buffer, &output.buffer, &info_buffers],
|
||||
);
|
||||
|
||||
crop(output, final_output_shape)
|
||||
}
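As a quick sanity check on the defaults baked into `matmul_tile_2d!` (B_M = B_N = 64, B_K = 32, T_M = T_N = 4), here is a minimal sketch of the constraints enforced by `matmul_parameter_assertions`; the function name and numbers below are illustrative, not part of the diff:

// Sketch only: checks the default macro parameters against the assertions above.
fn check_default_tiling_parameters() {
    let (b_m, b_n, b_k, t_m, t_n) = (64usize, 64usize, 32usize, 4usize, 4usize);

    // Shared memory: B_M x B_K and B_K x B_N elements must fit under the 8192-element cap.
    assert!(b_k * b_m.max(b_n) <= 8192); // 32 * 64 = 2048

    // Each thread owns a T_M x T_N tile, so the workgroup is (B_M / T_M) x (B_N / T_N).
    assert!(b_m % t_m == 0 && b_n % t_n == 0);
    assert!(b_k <= b_m.min(b_n)); // enough threads to fill both shared-memory tiles
    let (wg_x, wg_y) = (b_m / t_m, b_n / t_n);
    assert_eq!((wg_x, wg_y), (16, 16)); // 256 threads per workgroup
}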
@ -0,0 +1,12 @@
use super::base::{matmul_tiling_2d_launch, register_template};
use crate::{
    element::WgpuElement,
    kernel::{KernelSettings, SourceTemplate, StaticKernel},
    matmul_tile_2d,
    tensor::WgpuTensor,
};

matmul_tile_2d!(
    MatmulTiling2DContinuous,
    "../../../template/matmul/blocktiling_2d/continuous.wgsl"
);
@ -0,0 +1,12 @@
use super::base::{matmul_tiling_2d_launch, register_template};
use crate::{
    element::WgpuElement,
    kernel::{KernelSettings, SourceTemplate, StaticKernel},
    matmul_tile_2d,
    tensor::WgpuTensor,
};

matmul_tile_2d!(
    MatmulTiling2DContinuousVectorized,
    "../../../template/matmul/blocktiling_2d/continuous_vectorized.wgsl"
);
@ -0,0 +1,13 @@
mod base;
mod padding;

/// Loading is done in a continuous manner
pub mod continuous;
/// Loading is done in a continuous manner. lhs is transposed
pub mod continuous_vectorized;
/// Loading is done in a tile manner
pub mod tile;
/// Loading is done in a tile manner. lhs is transposed
pub mod tile_vectorized;

pub use tile_vectorized::*;
@ -0,0 +1,241 @@
|
|||
use std::ops::Range;
|
||||
|
||||
use burn_tensor::Shape;
|
||||
|
||||
use crate::{
|
||||
element::WgpuElement,
|
||||
kernel::{slice, slice_assign},
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
||||
use super::base::empty_from_context;
|
||||
|
||||
/// Pads a tensor with zeros so that its number of rows and columns
/// is divisible by the given divisors.
/// For instance, a tensor of shape [1000, 1000] with divisors 64 and 64
/// will be padded to [1024, 1024], with the extra 24 rows and columns filled with zeros.
|
||||
pub(super) fn pad_round<E: WgpuElement, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
row_divisor: usize,
|
||||
col_divisor: usize,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let row_modulo = tensor.shape.dims[D - 2] % row_divisor;
|
||||
let col_modulo = tensor.shape.dims[D - 1] % col_divisor;
|
||||
if row_modulo == 0 && col_modulo == 0 {
|
||||
return tensor;
|
||||
}
|
||||
let mut padded_shape = Vec::with_capacity(D);
|
||||
for i in 0..D - 2 {
|
||||
padded_shape.push(tensor.shape.dims[i]);
|
||||
}
|
||||
padded_shape.push(tensor.shape.dims[D - 2] - row_modulo + row_divisor);
|
||||
padded_shape.push(tensor.shape.dims[D - 1] - col_modulo + col_divisor);
|
||||
padding::<E, D>(tensor, padded_shape.into())
|
||||
}
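When a dimension is not already a multiple of its divisor, the rounded size is `dim - dim % divisor + divisor`. A minimal sketch reproducing the [1000, 1000] example from the doc comment; the helper below is hypothetical and only mirrors the per-dimension arithmetic:

// Round `dim` up to the next multiple of `divisor`, as `pad_round` does per dimension.
fn round_up(dim: usize, divisor: usize) -> usize {
    let modulo = dim % divisor;
    if modulo == 0 {
        dim
    } else {
        dim - modulo + divisor
    }
}

fn pad_round_examples() {
    assert_eq!(round_up(1000, 64), 1024); // 24 extra rows/columns of zeros
    assert_eq!(round_up(1024, 64), 1024); // already a multiple: unchanged
}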
|
||||
|
||||
/// Pads tensor by adding zeros when padded dim is larger than tensor dim
|
||||
fn padding<E: WgpuElement, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
padded_shape: Shape<D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let ranges = padded_shape
|
||||
.dims
|
||||
.iter()
|
||||
.map(|dim| 0..*dim)
|
||||
.collect::<Vec<Range<usize>>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
slice_assign::<E, D, D>(
|
||||
empty_from_context(tensor.context.clone(), &padded_shape),
|
||||
ranges,
|
||||
tensor,
|
||||
)
|
||||
}
|
||||
|
||||
/// Crops tensor by deleting values when cropped dim is smaller than tensor dim
|
||||
pub(super) fn crop<E: WgpuElement, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
cropped_shape: Shape<D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let ranges = cropped_shape
|
||||
.dims
|
||||
.iter()
|
||||
.map(|dim| 0..*dim)
|
||||
.collect::<Vec<Range<usize>>>()
|
||||
.try_into()
|
||||
.unwrap();
|
||||
slice::<E, D, D>(tensor, ranges)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tests::TestTensor;
|
||||
|
||||
#[test]
|
||||
fn padding_already_round_should_have_same_shape() {
|
||||
let row = 10;
|
||||
let row_divisor = 5;
|
||||
let col = 12;
|
||||
let col_divisor = 3;
|
||||
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
|
||||
let expected_shape = [row, col].into();
|
||||
|
||||
let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor);
|
||||
|
||||
assert!(padded.shape == expected_shape);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn padding_already_round_should_have_same_values() {
|
||||
let row = 10;
|
||||
let row_divisor = 5;
|
||||
let col = 12;
|
||||
let col_divisor = 3;
|
||||
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
|
||||
|
||||
let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor);
|
||||
|
||||
let padded = TestTensor::from_primitive(padded);
|
||||
padded.into_data().assert_approx_eq(&tensor.into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn padding_not_round_should_have_rounded_shape() {
|
||||
let row = 10;
|
||||
let row_divisor = 6;
|
||||
let col = 12;
|
||||
let col_divisor = 5;
|
||||
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
|
||||
let expected_shape = [12, 15].into();
|
||||
|
||||
let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor);
|
||||
|
||||
assert!(padded.shape == expected_shape);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn padding_not_round_should_have_same_values() {
|
||||
let row = 10;
|
||||
let row_divisor = 6;
|
||||
let col = 12;
|
||||
let col_divisor = 5;
|
||||
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
|
||||
|
||||
let padded = pad_round(tensor.clone().into_primitive(), row_divisor, col_divisor);
|
||||
|
||||
let padded = TestTensor::from_primitive(padded).to_data();
|
||||
let tensor = tensor.into_data();
|
||||
for i in 0..row {
|
||||
for j in 0..col {
|
||||
assert!(padded.value[i * 15 + j] == tensor.value[i * col + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn padding_not_round_should_have_zero_padding() {
|
||||
let row = 10;
|
||||
let row_divisor = 6;
|
||||
let col = 12;
|
||||
let col_divisor = 5;
|
||||
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
|
||||
|
||||
let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor);
|
||||
let padded = TestTensor::from_primitive(padded).to_data();
|
||||
|
||||
// check right of matrix
|
||||
for i in 0..row {
|
||||
for j in col..15 {
|
||||
assert!(padded.value[i * 15 + j] == 0.0);
|
||||
}
|
||||
}
|
||||
// check below matrix, including bottom right
|
||||
for i in row..12 {
|
||||
for j in 0..15 {
|
||||
assert!(padded.value[i * 15 + j] == 0.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn padding_works_with_batch() {
|
||||
let row = 10;
|
||||
let row_divisor = 4;
|
||||
let col = 12;
|
||||
let col_divisor = 5;
|
||||
let tensor = TestTensor::random([2, 3, row, col], burn_tensor::Distribution::Default);
|
||||
let expected_shape = [2, 3, 12, 15].into();
|
||||
|
||||
let padded = pad_round(tensor.into_primitive(), row_divisor, col_divisor);
|
||||
|
||||
assert!(padded.shape == expected_shape);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn crop_same_shape_should_be_unchanged_shape() {
|
||||
let row = 10;
|
||||
let col = 12;
|
||||
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
|
||||
let expected_shape = [row, col].into();
|
||||
|
||||
let unpadded = crop(tensor.into_primitive(), [row, col].into());
|
||||
|
||||
assert!(unpadded.shape == expected_shape);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn crop_same_shape_should_have_unchanged_values() {
|
||||
let row = 10;
|
||||
let col = 12;
|
||||
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
|
||||
|
||||
let unpadded = crop(tensor.clone().into_primitive(), [row, col].into());
|
||||
|
||||
let unpadded = TestTensor::from_primitive(unpadded).to_data();
|
||||
let tensor = tensor.into_data();
|
||||
for i in 0..row {
|
||||
for j in 0..col {
|
||||
assert!(unpadded.value[i * col + j] == tensor.value[i * col + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn crop_should_decrease_shape() {
|
||||
let row = 10;
|
||||
let keep_rows = 8;
|
||||
let col = 12;
|
||||
let keep_cols = 10;
|
||||
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
|
||||
let expected_shape = [keep_rows, keep_cols].into();
|
||||
|
||||
let unpadded = crop(tensor.into_primitive(), [keep_rows, keep_cols].into());
|
||||
|
||||
assert!(unpadded.shape == expected_shape);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn crop_should_keep_same_values() {
|
||||
let row = 4;
|
||||
let keep_rows = 3;
|
||||
let col = 4;
|
||||
let keep_cols = 3;
|
||||
let tensor = TestTensor::random([row, col], burn_tensor::Distribution::Default);
|
||||
|
||||
let unpadded = crop(
|
||||
tensor.clone().into_primitive(),
|
||||
[keep_rows, keep_cols].into(),
|
||||
);
|
||||
|
||||
let unpadded = TestTensor::from_primitive(unpadded).to_data();
|
||||
let tensor = tensor.into_data();
|
||||
println!("{:?}\n {:?}", unpadded, tensor);
|
||||
|
||||
for i in 0..keep_rows {
|
||||
for j in 0..keep_cols {
|
||||
assert!(unpadded.value[i * keep_cols + j] == tensor.value[i * col + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,13 @@
use crate::{
    element::WgpuElement,
    kernel::{KernelSettings, SourceTemplate, StaticKernel},
    matmul_tile_2d,
    tensor::WgpuTensor,
};

use super::base::{matmul_tiling_2d_launch, register_template};

matmul_tile_2d!(
    MatmulTiling2DTile,
    "../../../template/matmul/blocktiling_2d/tile.wgsl"
);
@ -0,0 +1,12 @@
use super::base::{matmul_tiling_2d_launch, register_template};
use crate::{
    element::WgpuElement,
    kernel::{KernelSettings, SourceTemplate, StaticKernel},
    matmul_tile_2d,
    tensor::WgpuTensor,
};

matmul_tile_2d!(
    MatmulTiling2DTileVectorized,
    "../../../template/matmul/blocktiling_2d/tile_vectorized.wgsl"
);
@ -0,0 +1,48 @@
use crate::{element::WgpuElement, tensor::WgpuTensor};
use burn_tensor::Shape;

pub(crate) fn shape_out<E: WgpuElement, const D: usize>(
    lhs: &WgpuTensor<E, D>,
    rhs: &WgpuTensor<E, D>,
) -> Shape<D> {
    let mut shape_out = [0; D];
    lhs.shape
        .dims
        .iter()
        .zip(rhs.shape.dims.iter())
        .enumerate()
        .for_each(|(index, (dim_lhs, dim_rhs))| {
            shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
        });
    shape_out[D - 2] = lhs.shape.dims[D - 2];
    shape_out[D - 1] = rhs.shape.dims[D - 1];
    Shape::new(shape_out)
}
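`shape_out` broadcasts the batch dimensions by taking the element-wise maximum and then forces the last two dimensions to lhs rows and rhs columns. A small sketch of the same rule on plain arrays; the shapes below are made up for illustration:

// Illustrative only: broadcast batch dims, then rows from lhs and cols from rhs.
fn broadcast_shape_example() {
    let lhs_dims = [2usize, 1, 64, 32]; // [batch_1, batch_2, M, K]
    let rhs_dims = [1usize, 5, 32, 128]; // [batch_1, batch_2, K, N]

    let mut out = [0usize; 4];
    for i in 0..4 {
        out[i] = lhs_dims[i].max(rhs_dims[i]);
    }
    out[2] = lhs_dims[2]; // M
    out[3] = rhs_dims[3]; // N

    assert_eq!(out, [2, 5, 64, 128]);
}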

#[cfg(test)]
pub(crate) mod tests {
    use crate::tensor::WgpuTensor;
    use crate::tests::{ReferenceTensor, TestTensor};
    use burn_tensor::Shape;

    pub(crate) fn same_as_reference<F, const D: usize, S>(func: F, shape_lhs: S, shape_rhs: S)
    where
        F: Fn(WgpuTensor<f32, D>, WgpuTensor<f32, D>) -> WgpuTensor<f32, D>,
        S: Into<Shape<D>>,
    {
        let x = ReferenceTensor::random(shape_lhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));
        let y = ReferenceTensor::random(shape_rhs, burn_tensor::Distribution::Uniform(-1.0, 1.0));

        let x_wgpu = TestTensor::from_data(x.to_data());
        let y_wgpu = TestTensor::from_data(y.to_data());

        let z_reference = x.matmul(y);

        let z = func(x_wgpu.into_primitive(), y_wgpu.into_primitive());
        let z = TestTensor::from_primitive(z);

        std::println!("{z}");
        std::println!("{z_reference}");
        z_reference.into_data().assert_approx_eq(&z.into_data(), 3);
    }
}
@ -4,7 +4,6 @@ mod cat;
mod comparison;
mod index;
mod mask;
mod matmul;
mod reduction;
mod source;
mod unary;

@ -12,11 +11,13 @@ mod unary_scalar;

pub use base::*;
pub use binary_elemwise::*;
pub use matmul::*;
pub use source::*;
pub use unary::*;
pub use unary_scalar::*;

/// Matmul kernels
pub mod matmul;

pub(crate) use cat::*;
pub(crate) use comparison::*;
pub(crate) use index::*;
@ -38,7 +38,9 @@ mod tests {

    pub type TestBackend = WgpuBackend<GraphicsApi, f32, i32>;
    pub type ReferenceBackend = burn_ndarray::NdArrayBackend<f32>;

    pub type TestTensor<const D: usize> = burn_tensor::Tensor<TestBackend, D>;
    pub type ReferenceTensor<const D: usize> = burn_tensor::Tensor<ReferenceBackend, D>;
    pub type TestTensorInt<const D: usize> = burn_tensor::Tensor<TestBackend, D, burn_tensor::Int>;

    burn_tensor::testgen_add!();
@ -1,7 +1,6 @@
use super::{numeric, BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
use crate::kernel::{
    self, matmul_tiling_2d_default, unary_default, unary_inplace_default, unary_scalar_default,
    unary_scalar_inplace_default,
    self, unary_default, unary_inplace_default, unary_scalar_default, unary_scalar_inplace_default,
};
use crate::unary_scalar_inplace;
use crate::{

@ -140,7 +139,7 @@ where
        let lhs = kernel::into_continuous(lhs);
        let rhs = kernel::into_continuous(rhs);

        matmul_tiling_2d_default::<FloatElem<Self>, D>(lhs, rhs)
        kernel::matmul::tile_vectorized::matmul_tiling_2d_default(lhs, rhs)
    }

    fn swap_dims<const D: usize>(
@ -0,0 +1,139 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const B_M = {{b_m}}u;
|
||||
const B_N = {{b_n}}u;
|
||||
const B_K = {{b_k}}u;
|
||||
const B_M_X_B_K = {{bm_x_bk}}u;
|
||||
const B_K_X_B_N = {{bk_x_bn}}u;
|
||||
const T_M = {{t_m}}u;
|
||||
const T_N = {{t_n}}u;
|
||||
const T_M_X_T_N = {{tm_x_tn}}u;
|
||||
|
||||
var<workgroup> shared_lhs: array<{{ elem }}, B_M_X_B_K>;
|
||||
var<workgroup> shared_rhs: array<{{ elem }}, B_K_X_B_N>;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(local_invocation_index) local_idx: u32,
|
||||
@builtin(workgroup_id) workgroup_id: vec3<u32>,
|
||||
) {
|
||||
let skip_row = workgroup_id.x * B_M;
|
||||
let skip_col = workgroup_id.y * B_N;
|
||||
|
||||
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
|
||||
let thread_row = (local_idx / n_thread_per_row) * T_M;
|
||||
let thread_col = (local_idx % n_thread_per_row) * T_N;
|
||||
|
||||
let row = skip_row + thread_row;
|
||||
let col = skip_col + thread_col;
|
||||
|
||||
let batch = global_id.z;
|
||||
|
||||
// Basic information
|
||||
let dim = info[0];
|
||||
let n_rows = info[6u * dim - 1u];
|
||||
let n_cols = info[6u * dim];
|
||||
let K = info[5u * dim - 1u];
|
||||
|
||||
// Calculate the corresponding offsets with support for broadcasting.
|
||||
let offset_output = batch * n_rows * n_cols;
|
||||
var offset_lhs: u32 = skip_row * K;
|
||||
var offset_rhs: u32 = skip_col;
|
||||
|
||||
let batch_dims = dim - 2u;
|
||||
for (var b: u32 = 1u; b <= batch_dims; b++) {
|
||||
let stride_lhs = info[b];
|
||||
let stride_rhs = info[b + dim];
|
||||
let stride_output = info[b + 2u * dim];
|
||||
let shape_lhs = info[b + 3u * dim];
|
||||
let shape_rhs = info[b + 4u * dim];
|
||||
|
||||
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
|
||||
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
|
||||
}
|
||||
|
||||
var results: array<{{ elem }}, T_M_X_T_N>;
|
||||
var register_M: array<{{ elem }}, T_M>;
|
||||
var register_N: array<{{ elem }}, T_N>;
|
||||
|
||||
let thread_offset = local_idx * T_M_X_T_N;
|
||||
|
||||
for (var k = 0u; k < K; k += B_K) {
|
||||
for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) {
|
||||
let lhs_sm_position = thread_offset + load_index;
|
||||
let block_row = lhs_sm_position / B_K;
|
||||
let block_col = lhs_sm_position % B_K;
|
||||
let lhs_position = offset_lhs + k + block_row * K + block_col;
|
||||
|
||||
if block_row < B_M {
|
||||
shared_lhs[lhs_sm_position] = lhs[lhs_position];
|
||||
}
|
||||
}
|
||||
|
||||
for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) {
|
||||
let rhs_sm_position = thread_offset + load_index;
|
||||
let block_row = rhs_sm_position / B_N;
|
||||
let block_col = rhs_sm_position % B_N;
|
||||
let rhs_position = offset_rhs + (k + block_row) * n_cols + block_col;
|
||||
|
||||
if block_row < B_K {
|
||||
shared_rhs[rhs_sm_position] = rhs[rhs_position];
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
|
||||
// Compute intermediate results
|
||||
// Results are cumulated in results array and updated at each block
|
||||
// Outer loop indicates which subcolumns/subrows to read from shared memories
|
||||
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
|
||||
// Load a subcolumn of values from lhs
|
||||
for (var tile_index = 0u; tile_index < T_M; tile_index++) {
|
||||
let lhs_sm_position = (thread_row + tile_index) * B_K + dot_index;
|
||||
register_M[tile_index] = shared_lhs[lhs_sm_position];
|
||||
}
|
||||
// Load a subrow of values from rhs
|
||||
for (var tile_index = 0u; tile_index < T_N; tile_index++) {
|
||||
let rhs_sm_position = thread_col + tile_index + dot_index * B_N;
|
||||
register_N[tile_index] = shared_rhs[rhs_sm_position];
|
||||
}
|
||||
// Multiply subcolumn and subrow and store results
|
||||
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
|
||||
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
|
||||
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
workgroupBarrier();
|
||||
}
|
||||
|
||||
// Write output matrix
|
||||
// Each thread is responsible for writing T_M x T_N results
|
||||
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
|
||||
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
|
||||
let current_row = row + res_idx_M;
|
||||
let current_col = col + res_idx_N;
|
||||
// Padding to multiples of B_M and B_N keeps the write below within the output bounds
|
||||
let result_position = res_idx_M * T_N + res_idx_N;
|
||||
let output_position = offset_output + current_row * n_cols + current_col;
|
||||
output[output_position] = results[result_position];
|
||||
}
|
||||
}
|
||||
}
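Each invocation maps its `local_idx` to a T_M x T_N tile of the workgroup's B_M x B_N block via `n_thread_per_row = ceil(B_N / T_N)`. A small Rust sketch of that index arithmetic, using the default B_N = 64, T_M = T_N = 4; the values and function name are illustrative only:

// Mirror of the index arithmetic at the top of the WGSL `main` above.
fn thread_tile_mapping_example() {
    let (b_n, t_m, t_n) = (64u32, 4u32, 4u32);
    let n_thread_per_row = (b_n - 1) / t_n + 1; // 16 threads per row of tiles

    let local_idx = 35u32;
    let thread_row = (local_idx / n_thread_per_row) * t_m; // (35 / 16) * 4 = 8
    let thread_col = (local_idx % n_thread_per_row) * t_n; // (35 % 16) * 4 = 12

    assert_eq!((thread_row, thread_col), (8, 12));
    // This thread accumulates the 4 x 4 output tile starting at (row 8, col 12)
    // inside its workgroup's 64 x 64 block.
}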
|
|
@ -0,0 +1,138 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> info: array<u32>;
|
||||
|
||||
const B_M = {{b_m}}u;
|
||||
const B_N = {{b_n}}u;
|
||||
const B_K = {{b_k}}u;
|
||||
const B_M_X_B_K = {{bm_x_bk}}u;
|
||||
const B_K_X_B_N = {{bk_x_bn}}u;
|
||||
const T_M = {{t_m}}u;
|
||||
const T_N = {{t_n}}u;
|
||||
const T_M_X_T_N = {{tm_x_tn}}u;

var<workgroup> shared_lhs: array<{{ elem }}, B_M_X_B_K>;
var<workgroup> shared_rhs: array<{{ elem }}, B_K_X_B_N>;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_index) local_idx: u32,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    let skip_row = workgroup_id.x * B_M;
    let skip_col = workgroup_id.y * B_N;

    let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
    let thread_row = (local_idx / n_thread_per_row) * T_M;
    let thread_col = (local_idx % n_thread_per_row) * T_N;

    let row = skip_row + thread_row;
    let col = skip_col + thread_col;

    let batch = global_id.z;

    // Basic information
    let dim = info[0];
    let n_rows = info[6u * dim - 1u];
    let n_cols = info[6u * dim];
    let K = info[5u * dim - 1u];

    // Calculate the corresponding offsets with support for broadcasting.
    let offset_output = batch * n_rows * n_cols;
    var offset_lhs: u32 = skip_row * K;
    var offset_rhs: u32 = skip_col;

    let batch_dims = dim - 2u;
    for (var b: u32 = 1u; b <= batch_dims; b++) {
        let stride_lhs = info[b];
        let stride_rhs = info[b + dim];
        let stride_output = info[b + 2u * dim];
        let shape_lhs = info[b + 3u * dim];
        let shape_rhs = info[b + 4u * dim];

        offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
        offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
    }

    var results: array<{{ elem }}, T_M_X_T_N>;
    var register_M: array<{{ elem }}, T_M>;
    var register_N: array<{{ elem }}, T_N>;

    let thread_offset = local_idx * T_M_X_T_N;

    for (var k = 0u; k < K; k += B_K) {
        // tile: let lhs_sm_position = current_row * B_K + current_col;
        // tile_vec: let lhs_sm_position = current_row + current_col * B_M;
        for (var load_index = 0u; load_index < T_M_X_T_N; load_index++) {
            let lhs_sm_position = thread_offset + load_index;
            let block_row = lhs_sm_position % B_M;
            let block_col = lhs_sm_position / B_M;
            let lhs_position = offset_lhs + k + block_row * K + block_col;

            if block_col < B_K {
                shared_lhs[lhs_sm_position] = lhs[lhs_position];
            }
        }

        for (var load_index = 0u; load_index < T_M_X_T_N; load_index++) {
            let rhs_sm_position = thread_offset + load_index;
            let block_row = rhs_sm_position / B_N;
            let block_col = rhs_sm_position % B_N;
            let rhs_position = offset_rhs + (k + block_row) * n_cols + block_col;

            if block_row < B_K {
                shared_rhs[rhs_sm_position] = rhs[rhs_position];
            }
        }

        workgroupBarrier();

        // Compute intermediate results
        // Results are accumulated in the results array and updated at each block
        // Outer loop indicates which subcolumns/subrows to read from shared memories
        for (var dot_index = 0u; dot_index < B_K; dot_index++) {
            // Load a subcolumn of values from lhs
            for (var tile_index = 0u; tile_index < T_M; tile_index++) {
                let lhs_sm_position = thread_row + tile_index + dot_index * B_M;
                register_M[tile_index] = shared_lhs[lhs_sm_position];
            }
            // Load a subrow of values from rhs
            for (var tile_index = 0u; tile_index < T_N; tile_index++) {
                let rhs_sm_position = thread_col + tile_index + dot_index * B_N;
                register_N[tile_index] = shared_rhs[rhs_sm_position];
            }
            // Multiply subcolumn and subrow and store results
            for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
                for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
                    results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
                }
            }
        }

        workgroupBarrier();
    }

    // Write output matrix
    // Each thread is responsible for writing T_M x T_N results
    for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
        for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
            let result_position = res_idx_M * T_N + res_idx_N;
            let output_position = offset_output + (row + res_idx_M) * n_cols + col + res_idx_N;
            output[output_position] = results[result_position];
        }
    }
}
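For reference, the load loop above spreads each thread's T_M x T_N contiguous shared-memory slots over a column-major lhs tile (stride B_M) and a row-major rhs tile (stride B_N). The standalone Rust sketch below mirrors that decomposition on the CPU so the covering property can be checked in isolation; the block and tile sizes are example values for illustration, not the kernel's defaults, and the sketch is not code from the crate.

// Standalone illustration of the shared-memory index decomposition used above.
// B_M, B_N, B_K, T_M, T_N are assumed example values.
fn main() {
    const B_M: usize = 64;
    const B_N: usize = 64;
    const B_K: usize = 32;
    const T_M: usize = 4;
    const T_N: usize = 4;

    let n_threads = (B_M / T_M) * (B_N / T_N);

    let mut lhs_written = vec![false; B_M * B_K];
    for local_idx in 0..n_threads {
        let thread_offset = local_idx * T_M * T_N;
        for load_index in 0..T_M * T_N {
            // Same decomposition as the WGSL kernel: column-major over the lhs tile.
            let lhs_sm_position = thread_offset + load_index;
            let block_row = lhs_sm_position % B_M;
            let block_col = lhs_sm_position / B_M;
            if block_col < B_K {
                // Column-major layout: slot index equals block_col * B_M + block_row.
                lhs_written[block_col * B_M + block_row] = true;
            }
        }
    }
    // Every slot of the B_M x B_K lhs tile is covered; positions past B_K are skipped.
    assert!(lhs_written.iter().all(|&w| w));
    println!("lhs tile ({} x {}) fully covered by {} threads", B_M, B_K, n_threads);
}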
@ -68,35 +68,31 @@ fn main(
        offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
    }

    // In case B_M % T_M != 0 or B_N % T_N != 0
    // A thread must not read out of its block
    let actual_T_M = min(B_M - thread_row, T_M);
    let actual_T_N = min(B_N - thread_col, T_N);

    var results: array<{{ elem }}, T_M_X_T_N>;
    var register_M: array<{{ elem }}, T_M>;
    var register_N: array<{{ elem }}, T_N>;

    let thread_offset = local_idx * T_M_X_T_N;

    for (var k = 0u; k < K; k += B_K) {
        // sm_limit ensures that although there are up to B_M x B_N writes to memory,
        // shared memories remain B_M x B_K (lhs) or B_K x B_N (rhs)
        // also ensures we do not read out of matrices if M % B_M != 0 or N % B_N != 0
        let sm_limit = min(B_K, K - k);

        // Load data into shared memories
        // Each thread is responsible for loading T_M x T_N values from both lhs and rhs
        for (var i = 0u; i < actual_T_M; i++) {
            for (var j = 0u; j < actual_T_N; j++) {
        for (var i = 0u; i < T_M; i++) {
            for (var j = 0u; j < T_N; j++) {
                let current_row = thread_row + i;
                let current_col = thread_col + j;

                if current_col < sm_limit {
                if current_col < B_K {
                    let lhs_sm_position = current_row * B_K + current_col;
                    let lhs_position = offset_lhs + k + current_row * K + current_col;
                    shared_lhs[lhs_sm_position] = lhs[lhs_position];
                }

                if current_row < sm_limit {
                if current_row < B_K {
                    let rhs_sm_position = current_row * B_N + current_col;
                    let rhs_position = offset_rhs + (k + current_row) * n_cols + current_col;
                    shared_rhs[rhs_sm_position] = rhs[rhs_position];
@ -104,26 +100,27 @@ fn main(
            }
        }

        workgroupBarrier();

        // Compute intermediate results
        // Results are accumulated in the results array and updated at each block
        // Outer loop indicates which subcolumns/subrows to read from shared memories
        for (var dot_index = 0u; dot_index < sm_limit; dot_index++) {
        for (var dot_index = 0u; dot_index < B_K; dot_index++) {
            // Load a subcolumn of values from lhs
            for (var tile_index = 0u; tile_index < actual_T_M; tile_index++) {
            for (var tile_index = 0u; tile_index < T_M; tile_index++) {
                let lhs_sm_position = (thread_row + tile_index) * B_K + dot_index;
                register_M[tile_index] = shared_lhs[lhs_sm_position];
            }
            // Load a subrow of values from rhs
            for (var tile_index = 0u; tile_index < actual_T_N; tile_index++) {
            for (var tile_index = 0u; tile_index < T_N; tile_index++) {
                let rhs_sm_position = thread_col + tile_index + dot_index * B_N;
                register_N[tile_index] = shared_rhs[rhs_sm_position];
            }
            // Multiply subcolumn and subrow and store results
            for (var res_idx_M = 0u; res_idx_M < actual_T_M; res_idx_M++) {
                for (var res_idx_N = 0u; res_idx_N < actual_T_N; res_idx_N++) {
                    results[res_idx_M * actual_T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
            for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
                for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
                    results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
                }
            }
        }
@ -133,16 +130,11 @@ fn main(

    // Write output matrix
    // Each thread is responsible for writing T_M x T_N results
    for (var res_idx_M = 0u; res_idx_M < actual_T_M; res_idx_M++) {
        for (var res_idx_N = 0u; res_idx_N < actual_T_N; res_idx_N++) {
            let current_row = row + res_idx_M;
            let current_col = col + res_idx_N;
            // Check that we are within the bounds of output matrix
            if current_row < n_rows && current_col < n_cols {
                let result_position = res_idx_M * actual_T_N + res_idx_N;
                let output_position = offset_output + current_row * n_cols + current_col;
                output[output_position] = results[result_position];
            }
    for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
        for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
            let result_position = res_idx_M * T_N + res_idx_N;
            let output_position = offset_output + (row + res_idx_M) * n_cols + col + res_idx_N;
            output[output_position] = results[result_position];
        }
    }
}
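Both versions of this kernel share the broadcasting offset computation (the offset_output / stride_output % shape * stride pattern). As a rough CPU illustration of that pattern only, the sketch below recovers the lhs and rhs batch offsets from the flat output offset; the helper name, shapes, and strides are made up for the example (the kernel reads these values from the info buffer), and this is not code from the crate.

// Hypothetical CPU sketch of the batch-offset computation with broadcasting.
fn batch_offsets(
    offset_output: usize,
    strides_out: &[usize],
    strides_lhs: &[usize],
    strides_rhs: &[usize],
    shape_lhs: &[usize],
    shape_rhs: &[usize],
) -> (usize, usize) {
    let mut offset_lhs = 0;
    let mut offset_rhs = 0;
    // Only the batch dimensions (everything before the last two) are walked.
    for b in 0..strides_out.len() {
        // Recover the position along this batch dimension from the output offset,
        // then remap it through each input's stride; `% shape` sends broadcast
        // (size-1) dimensions to 0, matching the WGSL expression.
        let out_index = offset_output / strides_out[b];
        offset_lhs += out_index % shape_lhs[b] * strides_lhs[b];
        offset_rhs += out_index % shape_rhs[b] * strides_rhs[b];
    }
    (offset_lhs, offset_rhs)
}

fn main() {
    // Batch dims only: lhs [2, 1], rhs [1, 3], output [2, 3]; matrices are 4x4.
    // Output batch (1, 2) starts at flat offset 1*48 + 2*16 = 80.
    let (l, r) = batch_offsets(
        80,
        &[48, 16], // output batch strides
        &[16, 16], // lhs batch strides (the broadcast dim keeps a dummy stride)
        &[48, 16], // rhs batch strides
        &[2, 1],   // lhs batch shape
        &[1, 3],   // rhs batch shape
    );
    // lhs advances by one 4x4 matrix (16), rhs by two (32).
    assert_eq!((l, r), (16, 32));
    println!("offset_lhs = {}, offset_rhs = {}", l, r);
}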
@ -0,0 +1,140 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;

@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;

@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;

@group(0)
@binding(3)
var<storage, read> info: array<u32>;

const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
const B_M_X_B_K = {{bm_x_bk}}u;
const B_K_X_B_N = {{bk_x_bn}}u;
const T_M = {{t_m}}u;
const T_N = {{t_n}}u;
const T_M_X_T_N = {{tm_x_tn}}u;

var<workgroup> shared_lhs: array<{{ elem }}, B_M_X_B_K>;
var<workgroup> shared_rhs: array<{{ elem }}, B_K_X_B_N>;

@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_index) local_idx: u32,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    let skip_row = workgroup_id.x * B_M;
    let skip_col = workgroup_id.y * B_N;

    let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
    let thread_row = (local_idx / n_thread_per_row) * T_M;
    let thread_col = (local_idx % n_thread_per_row) * T_N;

    let row = skip_row + thread_row;
    let col = skip_col + thread_col;

    let batch = global_id.z;

    // Basic information
    let dim = info[0];
    let n_rows = info[6u * dim - 1u];
    let n_cols = info[6u * dim];
    let K = info[5u * dim - 1u];

    // Calculate the corresponding offsets with support for broadcasting.
    let offset_output = batch * n_rows * n_cols;
    var offset_lhs: u32 = skip_row * K;
    var offset_rhs: u32 = skip_col;

    let batch_dims = dim - 2u;
    for (var b: u32 = 1u; b <= batch_dims; b++) {
        let stride_lhs = info[b];
        let stride_rhs = info[b + dim];
        let stride_output = info[b + 2u * dim];
        let shape_lhs = info[b + 3u * dim];
        let shape_rhs = info[b + 4u * dim];

        offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
        offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
    }

    var results: array<{{ elem }}, T_M_X_T_N>;
    var register_M: array<{{ elem }}, T_M>;
    var register_N: array<{{ elem }}, T_N>;

    let thread_offset = local_idx * T_M_X_T_N;

    for (var k = 0u; k < K; k += B_K) {
        // sm_limit ensures that although there are up to B_M x B_N writes to memory,
        // shared memories remain B_M x B_K (lhs) or B_K x B_N (rhs)
        // also ensures we do not read out of matrices if M % B_M != 0 or N % B_N != 0

        // Load data into shared memories
        // Each thread is responsible for loading T_M x T_N values from both lhs and rhs
        for (var i = 0u; i < T_M; i++) {
            for (var j = 0u; j < T_N; j++) {
                let current_row = thread_row + i;
                let current_col = thread_col + j;

                if current_col < B_K {
                    let lhs_sm_position = current_row + current_col * B_M;
                    let lhs_position = offset_lhs + k + current_row * K + current_col;
                    shared_lhs[lhs_sm_position] = lhs[lhs_position];
                }

                if current_row < B_K {
                    let rhs_sm_position = current_row * B_N + current_col;
                    let rhs_position = offset_rhs + (k + current_row) * n_cols + current_col;
                    shared_rhs[rhs_sm_position] = rhs[rhs_position];
                }
            }
        }

        workgroupBarrier();

        // Compute intermediate results
        // Results are accumulated in the results array and updated at each block
        // Outer loop indicates which subcolumns/subrows to read from shared memories
        for (var dot_index = 0u; dot_index < B_K; dot_index++) {
            // Load a subcolumn of values from lhs
            for (var tile_index = 0u; tile_index < T_M; tile_index++) {
                let lhs_sm_position = thread_row + tile_index + dot_index * B_M;
                register_M[tile_index] = shared_lhs[lhs_sm_position];
            }
            // Load a subrow of values from rhs
            for (var tile_index = 0u; tile_index < T_N; tile_index++) {
                let rhs_sm_position = thread_col + tile_index + dot_index * B_N;
                register_N[tile_index] = shared_rhs[rhs_sm_position];
            }
            // Multiply subcolumn and subrow and store results
            for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
                for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
                    results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
                }
            }
        }

        workgroupBarrier();
    }

    // Write output matrix
    // Each thread is responsible for writing T_M x T_N results
    for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
        for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
            let result_position = res_idx_M * T_N + res_idx_N;
            let output_position = offset_output + (row + res_idx_M) * n_cols + col + res_idx_N;
            output[output_position] = results[result_position];
        }
    }
}
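Stripped of the GPU-specific parts (workgroups, shared memory, barriers), the control flow of this kernel is the classic blocked matmul with per-thread register tiles. The Rust sketch below only illustrates that structure under simplifying assumptions (single matrix, row-major layout, dimensions divisible by the block and tile sizes, which are example values here); it is not code from the crate.

// CPU illustration of the 2D block/register tiling used by the WGSL kernel above.
// Assumes m % B_M == 0, n % B_N == 0, k % B_K == 0 to keep the sketch short.
const B_M: usize = 8;
const B_N: usize = 8;
const B_K: usize = 4;
const T_M: usize = 2;
const T_N: usize = 2;

fn matmul_tiled(lhs: &[f32], rhs: &[f32], out: &mut [f32], m: usize, k: usize, n: usize) {
    for skip_row in (0..m).step_by(B_M) {
        for skip_col in (0..n).step_by(B_N) {
            // One "workgroup": iterate over the per-thread register tiles of the block.
            for thread_row in (0..B_M).step_by(T_M) {
                for thread_col in (0..B_N).step_by(T_N) {
                    let mut results = [[0.0f32; T_N]; T_M];
                    // Walk K in B_K-wide blocks (the shared-memory staging is omitted here).
                    for kb in (0..k).step_by(B_K) {
                        for dot_index in 0..B_K {
                            for i in 0..T_M {
                                for j in 0..T_N {
                                    let row = skip_row + thread_row + i;
                                    let col = skip_col + thread_col + j;
                                    results[i][j] += lhs[row * k + kb + dot_index]
                                        * rhs[(kb + dot_index) * n + col];
                                }
                            }
                        }
                    }
                    // Each "thread" writes its T_M x T_N results.
                    for i in 0..T_M {
                        for j in 0..T_N {
                            let row = skip_row + thread_row + i;
                            let col = skip_col + thread_col + j;
                            out[row * n + col] = results[i][j];
                        }
                    }
                }
            }
        }
    }
}

fn main() {
    let (m, k, n) = (16, 8, 16);
    let lhs = vec![1.0f32; m * k];
    let rhs = vec![1.0f32; k * n];
    let mut out = vec![0.0f32; m * n];
    matmul_tiled(&lhs, &rhs, &mut out, m, k, n);
    // With all-ones inputs, every output element equals k.
    assert!(out.iter().all(|&v| v == k as f32));
}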