mirror of https://github.com/tracel-ai/burn.git
Perf/wgpu/matmul unpadded (#922)
parent 64e58b4463
commit 35df31f700
@@ -2,6 +2,7 @@ use burn_common::benchmark::{run_benchmark, Benchmark};
 use burn_tensor::backend::Backend;
 use burn_tensor::{Distribution, Shape, Tensor};
 use burn_wgpu::kernel::matmul::init_matmul_output;
+use burn_wgpu::kernel::matmul::unpadded::matmul_tiling_2d_unpadded;
 use burn_wgpu::kernel::matmul::vec4::matmul_tiling_2d_vec4;
 use burn_wgpu::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs;
 use burn_wgpu::WgpuDevice;
@@ -100,6 +101,11 @@ bench_matmul!(
     Tiling2DMatmulVec4,
     matmul_tiling_2d_vec4
 );
+bench_matmul!(
+    Tiling2DMatmulUnpaddedBenchmark,
+    Tiling2DMatmulUnpadded,
+    matmul_tiling_2d_unpadded
+);
 
 #[allow(dead_code)]
 /// Runs the benchmarks for wgpu matmul implementations
@@ -107,9 +113,9 @@ pub fn bench(device: &WgpuDevice) {
     const D: usize = 3;
     let num_repeats = 3;
     let batch_size = 3;
-    let m = 2048;
-    let k = 2048;
-    let n = 1024;
+    let m = 1007;
+    let k = 1023;
+    let n = 1005;
     let shape_lhs = Shape::new([batch_size, m, k]);
     let shape_rhs = Shape::new([batch_size, k, n]);
 
@@ -125,6 +131,7 @@ pub fn bench(device: &WgpuDevice) {
     }
    run_matmul_benchmark!(NaiveMatmulBenchmark);
    run_matmul_benchmark!(MemCoalescingMatmulBenchmark);
+   run_matmul_benchmark!(Tiling2DMatmulUnpaddedBenchmark);
    run_matmul_benchmark!(Tiling2DMatmulVec4LHSBenchmark);
    run_matmul_benchmark!(Tiling2DMatmulVec4Benchmark);
 }
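For context, a minimal sketch (not part of this diff) of how these benchmarks would be driven. It assumes the `bench` function above is in scope and that `WgpuDevice` implements `Default`, selecting the best available adapter:

fn main() {
    let device = burn_wgpu::WgpuDevice::default();
    // Runs the naive, memory-coalescing, unpadded, and vec4 tiling
    // benchmarks on the new 1007 x 1023 x 1005 shapes.
    bench(&device);
}

The switch from 2048/2048/1024 to 1007/1023/1005 is deliberate: the new sizes are not multiples of the block sizes, so they exercise the unpadded kernel's edge-case handling.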
@@ -13,7 +13,7 @@ pub(crate) const B_N: usize = 64;
 pub(crate) const B_K: usize = 32;
 pub(crate) const WORKGROUP_SIZE: usize = 16;
 
-pub(super) fn make_workgroup<const D: usize>(output_shape: Shape<D>) -> WorkGroup {
+pub(super) fn make_workgroup<const D: usize>(output_shape: &Shape<D>) -> WorkGroup {
     let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32;
     let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32;
     let mut num_blocks_z = 1;
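The grid size is a ceiling division of the output dimensions by the block sizes, so partially filled edge blocks still get dispatched. A standalone sketch of the same arithmetic in integers (illustrative only; `B_N = 64` is taken from the constant above):

fn num_blocks(dim: usize, block: usize) -> u32 {
    // Integer ceiling division, equivalent to the f32::ceil above.
    ((dim + block - 1) / block) as u32
}

fn main() {
    // With the new benchmark width n = 1005 and B_N = 64:
    // ceil(1005 / 64) = 16 blocks along the columns.
    assert_eq!(num_blocks(1005, 64), 16);
}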
@@ -71,7 +71,7 @@ pub(super) fn matmul_tiling_2d_launch<
         rounded_output_shape.clone(),
     );
 
-    let workgroup = make_workgroup(rounded_output_shape);
+    let workgroup = make_workgroup(&rounded_output_shape);
     let info_handle = make_info_handle(&lhs, &rhs, &rounded_output);
 
     lhs.client.execute(
@@ -1,6 +1,9 @@
 mod base;
 mod padding;
 
+/// WGSL vec4 primitives are used on left and right hand tensor,
+/// padding is avoided through the use of conditions in the kernel
+pub mod unpadded;
 /// WGSL vec4 primitives are used on left and right hand tensor
 pub mod vec4;
 /// WGSL vec4 primitives are used on left hand tensor
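The difference between the variants is how they handle shapes that are not multiples of the block sizes: the padded vec4 kernels round every dimension up and zero-pad, while the unpadded kernel bounds-checks inside the kernel. An illustrative check (hypothetical helper, not in this diff):

// True when the padded kernels would have to round m or n up to the
// next multiple of the block sizes, doing wasted work on zero elements
// that the unpadded kernel instead skips with in-kernel conditions.
fn needs_padding(m: usize, n: usize, b_m: usize, b_n: usize) -> bool {
    m % b_m != 0 || n % b_n != 0
}

fn main() {
    // The new benchmark shapes (1007 x 1005) are deliberately unaligned.
    assert!(needs_padding(1007, 1005, 64, 64));
}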
@@ -0,0 +1,195 @@
+use burn_tensor::Element;
+
+use crate::{
+    compute::DynamicKernel,
+    element::WgpuElement,
+    kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource},
+    tensor::WgpuTensor,
+};
+use std::marker::PhantomData;
+
+use crate::kernel_wgsl;
+
+use super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZE};
+
+kernel_wgsl!(
+    MatmulTiling2DUnpaddedRaw,
+    "../../../template/matmul/blocktiling_2d/unpadded.wgsl"
+);
+
+#[derive(new, Debug)]
+struct MatmulTiling2DUnpadded<E: WgpuElement> {
+    _elem: PhantomData<E>,
+}
+
+impl<E: WgpuElement> DynamicKernelSource for MatmulTiling2DUnpadded<E> {
+    fn source(&self) -> SourceTemplate {
+        MatmulTiling2DUnpaddedRaw::source()
+            .register("b_m", B_M.to_string())
+            .register("b_n", B_N.to_string())
+            .register("b_k", B_K.to_string())
+            .register("bm_x_bk_4", (B_M * B_K / 4).to_string())
+            .register("bk_x_bn_4", (B_K * B_N / 4).to_string())
+            .register("workgroup_size_x", WORKGROUP_SIZE.to_string())
+            .register("workgroup_size_y", WORKGROUP_SIZE.to_string())
+            .register("workgroup_size_z", "1".to_string())
+            .register("elem", E::type_name())
+            .register("int", "i32")
+    }
+
+    fn id(&self) -> String {
+        std::format!("{:?}", self)
+    }
+}
+
+/// Matrix multiplication using tiling 2d algorithm with
+/// vec4 primitive on both lhs and rhs, with no padding needed
+pub fn matmul_tiling_2d_unpadded<E: WgpuElement + Element, const D: usize>(
+    lhs: WgpuTensor<E, D>,
+    rhs: WgpuTensor<E, D>,
+    out: WgpuTensor<E, D>,
+) -> WgpuTensor<E, D> {
+    let lhs = match lhs.batch_swapped_with_row_col() {
+        true => into_contiguous(lhs),
+        false => lhs,
+    };
+    let rhs = match rhs.batch_swapped_with_row_col() {
+        true => into_contiguous(rhs),
+        false => rhs,
+    };
+
+    let workgroup = make_workgroup(&out.shape);
+    let info_handle = make_info_handle(&lhs, &rhs, &out);
+
+    lhs.client.execute(
+        Box::new(DynamicKernel::new(
+            MatmulTiling2DUnpadded::<E>::new(),
+            workgroup,
+        )),
+        &[&lhs.handle, &rhs.handle, &out.handle, &info_handle],
+    );
+
+    out
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::kernel::matmul::utils::tests::{same_as_reference, same_as_reference_swapped_dims};
+
+    #[test]
+    pub fn test_matmul_unpadded_straightforward() {
+        test_with_params(1, 2, 1, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_shapes_smaller_than_blocks() {
+        test_with_params(8, 8, 8, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_shapes_equal_blocks() {
+        test_with_params(64, 32, 64, 2, 2);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_m_exceeds_block() {
+        test_with_params(75, 32, 64, 2, 2);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_k_exceeds_block() {
+        test_with_params(64, 33, 32, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_irregular_shape() {
+        test_with_params(123, 255, 72, 3, 5);
+    }
+
+    #[test]
+    pub fn test64_matmul_unpadded_n_exceeds_block() {
+        test_with_params(64, 32, 75, 2, 2);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_n_smaller_than_m() {
+        test_with_params(8, 8, 3, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_m_smaller_than_n() {
+        test_with_params(3, 8, 8, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_k_smaller_than_m_n() {
+        test_with_params(8, 3, 8, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_k_larger_than_m_n() {
+        test_with_params(8, 48, 8, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_multibatch_1_dim() {
+        test_with_params(8, 8, 8, 3, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_multibatch_2_dims() {
+        test_with_params(8, 8, 8, 3, 4);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_blocks_divide_shapes_unevenly() {
+        test_with_params(7, 7, 7, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_medium() {
+        test_with_params(17, 16, 16, 1, 1);
+    }
+
+    #[test]
+    pub fn test_matmul_unpadded_large() {
+        test_with_params(134, 242, 250, 1, 1);
+    }
+
+    fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) {
+        let func = matmul_tiling_2d_unpadded;
+        let shape_lhs = [batch_1, batch_2, m, k];
+        let shape_rhs = [batch_1, batch_2, k, n];
+        same_as_reference(func, shape_lhs, shape_rhs);
+    }
+
+    #[test]
+    fn test_matmul_tiling_2d_primitive_swapped_batches_no_padding() {
+        let matmul_func = matmul_tiling_2d_unpadded;
+        let swap = [0, 1];
+        let shape_lhs = [3, 2, 4, 4];
+        let shape_rhs = [3, 2, 4, 4];
+        same_as_reference_swapped_dims(matmul_func, swap, swap, shape_lhs, shape_rhs);
+    }
+
+    #[test]
+    fn test_matmul_tiling_2d_primitive_swapped_row_col_no_padding() {
+        let matmul_func = matmul_tiling_2d_unpadded;
+        let swap_lhs = [0, 0];
+        let swap_rhs = [2, 3];
+        let shape_lhs = [3, 2, 4, 4];
+        let shape_rhs = [3, 2, 4, 4];
+        same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs);
+    }
+
+    #[test]
+    fn test_matmul_tiling_2d_primitive_swapped_row_with_batch_no_padding() {
+        let matmul_func = matmul_tiling_2d_unpadded;
+        let swap_lhs = [0, 3];
+        let swap_rhs = [0, 2];
+        let shape_lhs = [4, 4, 4, 4];
+        let shape_rhs = [4, 4, 4, 4];
+        same_as_reference_swapped_dims(matmul_func, swap_lhs, swap_rhs, shape_lhs, shape_rhs);
+    }
+}
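A minimal usage sketch (not from the diff): allocating the output with `init_matmul_output`, as the benchmark file above does, and then launching the unpadded kernel. The module paths are assumptions based on the imports shown in this commit:

use burn_tensor::Element;
use burn_wgpu::kernel::matmul::{init_matmul_output, unpadded::matmul_tiling_2d_unpadded};
use burn_wgpu::{element::WgpuElement, tensor::WgpuTensor};

// Hypothetical helper: multiplies two tensors of arbitrary,
// non-block-aligned shape without any padding step.
fn matmul_any_shape<E: WgpuElement + Element, const D: usize>(
    lhs: WgpuTensor<E, D>,
    rhs: WgpuTensor<E, D>,
) -> WgpuTensor<E, D> {
    // Allocates the output buffer with the broadcast batch dims
    // and trailing [.., m, n] shape.
    let out = init_matmul_output(&lhs, &rhs);
    matmul_tiling_2d_unpadded(lhs, rhs, out)
}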
@@ -12,18 +12,18 @@ use crate::kernel_wgsl;
 use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE};
 
 kernel_wgsl!(
-    MatmulTiling2Dvec4RHSRaw,
+    MatmulTiling2Dvec4Raw,
     "../../../template/matmul/blocktiling_2d/vec4.wgsl"
 );
 
 #[derive(new, Debug)]
-struct MatmulTiling2Dvec4RHS<E: WgpuElement> {
+struct MatmulTiling2Dvec4<E: WgpuElement> {
     _elem: PhantomData<E>,
 }
 
-impl<E: WgpuElement> DynamicKernelSource for MatmulTiling2Dvec4RHS<E> {
+impl<E: WgpuElement> DynamicKernelSource for MatmulTiling2Dvec4<E> {
     fn source(&self) -> SourceTemplate {
-        MatmulTiling2Dvec4RHSRaw::source()
+        MatmulTiling2Dvec4Raw::source()
             .register("b_m", B_M.to_string())
             .register("b_n", B_N.to_string())
             .register("b_k", B_K.to_string())
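The `register` calls bind values into the WGSL template before compilation: each `{{ key }}` placeholder in the shader source is replaced by the registered string. A self-contained sketch of the substitution idea (illustrative only, not burn's actual SourceTemplate implementation):

fn substitute(template: &str, key: &str, value: &str) -> String {
    // Handles both `{{key}}` and `{{ key }}` spellings used in the shader.
    template
        .replace(&format!("{{{{{key}}}}}"), value)
        .replace(&format!("{{{{ {key} }}}}"), value)
}

fn main() {
    let src = "const B_M = {{b_m}}u;";
    assert_eq!(substitute(src, "b_m", "64"), "const B_M = 64u;");
}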
@@ -48,7 +48,7 @@ pub fn matmul_tiling_2d_vec4<E: WgpuElement + Element, const D: usize>(
     rhs: WgpuTensor<E, D>,
     out: WgpuTensor<E, D>,
 ) -> WgpuTensor<E, D> {
-    let kernel = MatmulTiling2Dvec4RHS::<E>::new();
+    let kernel = MatmulTiling2Dvec4::<E>::new();
     matmul_tiling_2d_launch(lhs, rhs, out, kernel)
 }
 
@@ -82,6 +82,11 @@ impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet
             rhs.clone(),
             out.clone(),
         )),
+        Box::new(Vec4TilingMatmulUnpaddedDefault::<E, 3>::new(
+            lhs.clone(),
+            rhs.clone(),
+            out.clone(),
+        )),
         Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, 3>::new(lhs, rhs, out)),
     ]
 }
@@ -97,7 +102,10 @@ impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet
             2 => Box::new(Vec4TilingMatmulDefault::<E, D>::new(
                 self.lhs, self.rhs, self.out,
             )),
-            3 => Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, D>::new(
+            3 => Box::new(Vec4TilingMatmulUnpaddedDefault::<E, D>::new(
+                self.lhs, self.rhs, self.out,
+            )),
+            4 => Box::new(Vec4LhsOnlyTilingMatmulDefault::<E, D>::new(
                 self.lhs, self.rhs, self.out,
             )),
             _ => panic!("Fastest index is out of bound"),
@@ -162,18 +170,24 @@ matmul_tune_ops!(MemoryCoalescingMatmulW16x16, |lhs, rhs, out| {
     crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16)
 });
 
-// Probably the fastest on MacOS.
+// Maybe the fastest on MacOS.
 matmul_tune_ops!(
     Vec4LhsOnlyTilingMatmulDefault,
     crate::kernel::matmul::vec4_lhs::matmul_tiling_2d_vec4_lhs
 );
 
-// Probably the fastest.
+// Probably the fastest for block-aligned (fixed) sizes.
 matmul_tune_ops!(
     Vec4TilingMatmulDefault,
     crate::kernel::matmul::vec4::matmul_tiling_2d_vec4
 );
 
+// Probably the fastest otherwise.
+matmul_tune_ops!(
+    Vec4TilingMatmulUnpaddedDefault,
+    crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded
+);
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -0,0 +1,243 @@
+@group(0)
+@binding(0)
+var<storage, read> lhs: array<{{ elem }}>;
+
+@group(0)
+@binding(1)
+var<storage, read> rhs: array<{{ elem }}>;
+
+@group(0)
+@binding(2)
+var<storage, read_write> output: array<{{ elem }}>;
+
+@group(0)
+@binding(3)
+var<storage, read> info: array<u32>;
+
+const B_M = {{b_m}}u;
+const B_N = {{b_n}}u;
+const B_K = {{b_k}}u;
+const B_M_X_B_K_4 = {{bm_x_bk_4}}u;
+const B_K_X_B_N_4 = {{bk_x_bn_4}}u;
+
+const T_M = 4u;
+const T_N = 4u;
+const T_M_X_T_N = 16u;
+
+var<workgroup> shared_lhs: array<vec4<{{ elem }}>, B_M_X_B_K_4>;
+var<workgroup> shared_rhs: array<vec4<{{ elem }}>, B_K_X_B_N_4>;
+
+@compute
+@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
+fn main(
+    @builtin(global_invocation_id) global_id: vec3<u32>,
+    @builtin(local_invocation_index) local_idx: u32,
+    @builtin(workgroup_id) workgroup_id: vec3<u32>,
+) {
+    let skip_row = workgroup_id.x * B_M;
+    let skip_col = workgroup_id.y * B_N;
+
+    let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
+
+    // Position of the first element of the thread, relative to the block
+    let thread_row = (local_idx / n_thread_per_row) * T_M;
+    let thread_col = (local_idx % n_thread_per_row) * T_N;
+
+    // Position of the first element of the thread, in absolute (in one batch)
+    let row = skip_row + thread_row;
+    let col = skip_col + thread_col;
+
+    let batch = global_id.z;
+
+    // Basic information
+    let dim = info[0];
+    let n_rows = info[6u * dim - 1u];
+    let n_cols = info[6u * dim];
+    let K = info[5u * dim - 1u];
+
+    // Row / col strides
+    let lhs_stride_row = info[dim - 1u];
+    let lhs_stride_col = info[dim];
+    let rhs_stride_row = info[2u * dim - 1u];
+    let rhs_stride_col = info[2u * dim];
+    let out_stride_row = info[3u * dim - 1u];
+    let out_stride_col = info[3u * dim];
+
+    // Calculate the corresponding offsets with support for broadcasting.
+    let offset_output = batch * n_rows * n_cols;
+    var offset_lhs: u32 = skip_row * lhs_stride_row;
+    var offset_rhs: u32 = skip_col * rhs_stride_col;
+
+    let batch_dims = dim - 2u;
+    for (var b: u32 = 1u; b <= batch_dims; b++) {
+        let stride_lhs = info[b];
+        let stride_rhs = info[b + dim];
+        let stride_output = info[b + 2u * dim];
+        let shape_lhs = info[b + 3u * dim];
+        let shape_rhs = info[b + 4u * dim];
+
+        offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
+        offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
+    }
+
+    // Registers used in the compute pass
+    var results: array<{{ elem }}, T_M_X_T_N>;
+    var register_M: vec4<{{ elem }}>;
+    var register_N: vec4<{{ elem }}>;
+
+    // How close the thread is to the end of the matrix.
+    // If < 4 then it is an edge case
+    let remain_row_lhs = n_rows - row;
+    let remain_col_rhs = n_cols - col;
+
+    for (var k = 0u; k < K; k += B_K) {
+
+        // LHS LOAD PASS
+
+        // For the 4 vec4 columns of this thread
+        for (var j = 0u; j < 4u; j++) {
+
+            // The precise column this iteration loads
+            let current_col = thread_col + j;
+
+            // Position of the column vec4 in shared memory
+            let lhs_sm_position = (thread_row / 4u) * B_K + current_col;
+
+            // To avoid overwriting the following row in shared memory
+            if current_col < B_K {
+                // To pad with zeros if outside lhs
+                if current_col + k < K && remain_row_lhs >= 1u {
+                    let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row;
+                    let lhs_position1 = lhs_position0 + lhs_stride_row;
+                    let lhs_position2 = lhs_position1 + lhs_stride_row;
+                    let lhs_position3 = lhs_position2 + lhs_stride_row;
+
+                    if remain_row_lhs >= 4u {
+                        shared_lhs[lhs_sm_position] = vec4(
+                            lhs[lhs_position0],
+                            lhs[lhs_position1],
+                            lhs[lhs_position2],
+                            lhs[lhs_position3],
+                        );
+                    } else if remain_row_lhs == 3u {
+                        shared_lhs[lhs_sm_position] = vec4(
+                            lhs[lhs_position0],
+                            lhs[lhs_position1],
+                            lhs[lhs_position2],
+                            0.
+                        );
+                    } else if remain_row_lhs == 2u {
+                        shared_lhs[lhs_sm_position] = vec4(
+                            lhs[lhs_position0],
+                            lhs[lhs_position1],
+                            0.,
+                            0.
+                        );
+                    } else if remain_row_lhs == 1u {
+                        shared_lhs[lhs_sm_position] = vec4(
+                            lhs[lhs_position0],
+                            0.,
+                            0.,
+                            0.
+                        );
+                    }
+                } else {
+                    shared_lhs[lhs_sm_position] = vec4(0., 0., 0., 0.);
+                }
+            }
+        }
+
+        // RHS LOAD PASS
+
+        for (var i = 0u; i < 4u; i++) {
+            let current_row = thread_row + i;
+
+            let rhs_sm_position = (current_row * B_N + thread_col) / 4u;
+
+            if current_row < B_K {
+                if current_row + k < K && remain_col_rhs >= 1u {
+
+                    let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col;
+                    let rhs_position1 = rhs_position0 + rhs_stride_col;
+                    let rhs_position2 = rhs_position1 + rhs_stride_col;
+                    let rhs_position3 = rhs_position2 + rhs_stride_col;
+
+                    if remain_col_rhs >= 4u {
+                        shared_rhs[rhs_sm_position] = vec4(
+                            rhs[rhs_position0],
+                            rhs[rhs_position1],
+                            rhs[rhs_position2],
+                            rhs[rhs_position3],
+                        );
+                    } else if remain_col_rhs == 3u {
+                        shared_rhs[rhs_sm_position] = vec4(
+                            rhs[rhs_position0],
+                            rhs[rhs_position1],
+                            rhs[rhs_position2],
+                            0.
+                        );
+                    } else if remain_col_rhs == 2u {
+                        shared_rhs[rhs_sm_position] = vec4(
+                            rhs[rhs_position0],
+                            rhs[rhs_position1],
+                            0.,
+                            0.
+                        );
+                    } else if remain_col_rhs == 1u {
+                        shared_rhs[rhs_sm_position] = vec4(
+                            rhs[rhs_position0],
+                            0.,
+                            0.,
+                            0.
+                        );
+                    }
+                } else {
+                    shared_rhs[rhs_sm_position] = vec4(0., 0., 0., 0.);
+                }
+            }
+        }
+
+        workgroupBarrier();
+
+        // COMPUTE PASS
+
+        // Compute intermediate results
+        // Results are accumulated in the results array, updated at each block
+        // The outer loop indicates which subcolumns/subrows to read from shared memory
+        for (var dot_index = 0u; dot_index < B_K; dot_index++) {
+
+            // Load a subcolumn of values from lhs
+            let lhs_sm_position = (thread_row / 4u) * B_K + dot_index;
+            register_M = shared_lhs[lhs_sm_position];
+
+            // Load a subrow of values from rhs
+            let rhs_sm_position = (dot_index * B_N + thread_col) / 4u;
+            register_N = shared_rhs[rhs_sm_position];
+
+            // Multiply subcolumn and subrow and store results
+            for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
+                for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
+                    results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
+                }
+            }
+        }
+
+        workgroupBarrier();
+    }
+
+    // OUTPUT PASS
+
+    // Write the output matrix
+    // Each thread is responsible for writing T_M x T_N results
+    for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
+        for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
+            let row_index = row + res_idx_M;
+            let col_index = col + res_idx_N;
+            if row_index < n_rows && col_index < n_cols {
+                let result_position = res_idx_M * T_N + res_idx_N;
+                let output_position = offset_output + row_index * out_stride_row + col_index * out_stride_col;
+                output[output_position] = results[result_position];
+            }
+        }
+    }
+}
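For intuition, a scalar CPU model (illustrative only, not part of the commit) of the register tiling done in the compute pass: each thread owns a T_M x T_N tile of the output and, for every dot_index, multiplies one subcolumn of lhs against one subrow of rhs, accumulating T_M * T_N partial products. The explicit bounds check mirrors the zero-filled vec4 loads and the guarded output writes that let the kernel run without padding:

const T_M: usize = 4;
const T_N: usize = 4;

// One thread's accumulation over the full K dimension, on row-major f32
// matrices: lhs is n_rows x k_dim, rhs is k_dim x n_cols.
fn thread_tile(
    lhs: &[f32],
    rhs: &[f32],
    k_dim: usize,
    n_rows: usize,
    n_cols: usize,
    row: usize,
    col: usize,
) -> [f32; T_M * T_N] {
    let mut results = [0.0f32; T_M * T_N];
    for dot_index in 0..k_dim {
        for i in 0..T_M {
            for j in 0..T_N {
                // Out-of-range rows/cols contribute nothing, matching the
                // kernel's zero padding and guarded output pass.
                if row + i < n_rows && col + j < n_cols {
                    results[i * T_N + j] +=
                        lhs[(row + i) * k_dim + dot_index] * rhs[dot_index * n_cols + (col + j)];
                }
            }
        }
    }
    results
}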