From 43877da1f26688ce7cb4e42360e6089e7f36ba49 Mon Sep 17 00:00:00 2001
From: louisfd
Date: Thu, 4 Jul 2024 13:42:26 -0400
Subject: [PATCH] offset check bounds work

---
 .../src/kernel/matmul/tiling2d_cube/base.rs   |  35 +-
 .../kernel/matmul/tiling2d_cube/block_loop.rs |  15 +-
 .../tiling2d_cube/load_shared_memory.rs       | 353 +++++++++++++++---
 crates/burn-jit/src/tests/matmul.rs           |  48 +++
 crates/burn-jit/src/tests/matmul_cube.rs      |  24 ++
 5 files changed, 393 insertions(+), 82 deletions(-)

diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs
index 80d6c32d2..8879bdcc6 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs
@@ -23,7 +23,7 @@ fn tiling2d_cube<F: Float>(
     out: &mut Tensor<F>,
     config: Comptime<CubeTiling2dConfig>,
 ) {
-    let info = get_info::<F>(lhs, rhs, out);
+    let dims = get_dims::<F>(lhs, rhs);
     let coordinates = calculate_coordinates(CUBE_POS_X, CUBE_POS_Y, UNIT_POS, config);
     let offsets = calculate_batch_offsets::<F>(lhs, rhs, out, CUBE_POS_Z);
     let shared_memories = make_shared_memories::<F>(config);
@@ -35,20 +35,17 @@ fn tiling2d_cube<F: Float>(
         offsets,
         shared_memories,
         config,
-        info,
+        dims,
     );
 }
 
 #[derive(CubeType, Copy, Clone)]
 /// Information available at runtime only
 /// Strides assume contiguous
-pub(crate) struct CubeTiling2dInfo {
-    pub dim_m: UInt,
-    pub dim_k: UInt,
-    pub dim_n: UInt,
-    pub lhs_stride: UInt,
-    pub rhs_stride: UInt,
-    pub out_stride: UInt,
+pub(crate) struct Dimensions {
+    pub m: UInt,
+    pub k: UInt,
+    pub n: UInt,
 }
 
 #[derive(CubeType, Copy, Clone)]
@@ -75,25 +72,15 @@ pub(crate) struct Coordinates {
 }
 
 #[cube]
-fn get_info<F: Float>(lhs: &Tensor<F>, rhs: &Tensor<F>, out: &Tensor<F>) -> CubeTiling2dInfo {
+fn get_dims<F: Float>(lhs: &Tensor<F>, rhs: &Tensor<F>) -> Dimensions {
     let rank = lhs.rank();
     let first_dim = rank - UInt::new(2);
     let second_dim = rank - UInt::new(1);
-    let dim_m = lhs.shape(first_dim);
-    let dim_k = lhs.shape(second_dim);
-    let dim_n = rhs.shape(second_dim);
-    let lhs_stride = lhs.stride(first_dim);
-    let rhs_stride = rhs.stride(first_dim);
-    let out_stride = out.stride(first_dim);
+    let m = lhs.shape(first_dim);
+    let k = lhs.shape(second_dim);
+    let n = rhs.shape(second_dim);
 
-    CubeTiling2dInfo {
-        dim_m,
-        dim_k,
-        dim_n,
-        lhs_stride,
-        rhs_stride,
-        out_stride,
-    }
+    Dimensions { m, k, n }
 }
 
 #[cube]
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs
index b54dc7484..7dad105c7 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/block_loop.rs
@@ -3,7 +3,7 @@ use burn_cube::prelude::*;
 use crate::kernel::matmul::config::CubeTiling2dConfig;
 
 use super::{
-    base::{BatchOffsets, Coordinates, CubeTiling2dInfo, SharedMemories},
+    base::{BatchOffsets, Coordinates, Dimensions, SharedMemories},
     compute_loop::{compute_loop, compute_loop_expand},
     load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand},
     write_output::{write_to_output, write_to_output_expand},
@@ -18,7 +18,7 @@ pub(crate) fn block_loop<F: Float>(
     offsets: BatchOffsets,
     shared: SharedMemories<F>,
     config: Comptime<CubeTiling2dConfig>,
-    info: CubeTiling2dInfo,
+    dims: Dimensions,
 ) {
     let block_size_k = Comptime::map(config, |c| c.block_size_k);
     let mut results = init_results::<F>(config);
@@ -28,7 +28,7 @@ pub(crate) fn block_loop<F: Float>(
     for k in range(0u32, n_loops, Comptime::new(false)) {
         let k = k * Comptime::runtime(block_size_k);
 
-        load_to_shared_memories(lhs, rhs, coordinates, k, offsets, shared, config, info);
+        load_to_shared_memories(lhs, rhs, coordinates, k, offsets, shared, config, dims);
 
         sync_units();
 
@@ -37,14 +37,7 @@ pub(crate) fn block_loop<F: Float>(
         sync_units();
     }
 
-    write_to_output::<F>(
-        out,
-        &results,
-        coordinates,
-        offsets.out,
-        info.out_stride,
-        config,
-    );
+    write_to_output::<F>(out, &results, coordinates, offsets.out, dims.n, config);
 }
 
 #[cube]
diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs
index 184535473..eb5c666c7 100644
--- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs
+++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs
@@ -3,7 +3,7 @@ use burn_cube::prelude::*;
 use crate::kernel::matmul::config::CubeTiling2dConfig;
 
 use super::{
-    base::{BatchOffsets, Coordinates, CubeTiling2dInfo, SharedMemories},
+    base::{BatchOffsets, Coordinates, Dimensions, SharedMemories},
     tile_read::{read_tile_from_global_memory, read_tile_from_global_memory_expand},
     tile_write::{
         write_tile_plain, write_tile_plain_expand, write_tile_transposed,
@@ -20,23 +20,23 @@ pub(crate) fn load_to_shared_memories<F: Float>(
     offsets: BatchOffsets,
     shared: SharedMemories<F>,
     config: Comptime<CubeTiling2dConfig>,
-    info: CubeTiling2dInfo,
+    dims: Dimensions,
 ) {
     let lhs_transposed = Comptime::map(config, |c| c.lhs_transposed);
     let rhs_transposed = Comptime::map(config, |c| c.rhs_transposed);
 
     // Lhs must be loaded as transposed. If it already is transposed in global memory, we load as plain.
     if Comptime::get(lhs_transposed) {
-        // load_lhs_plain::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
+        load_lhs_plain::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, dims);
     } else {
-        load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
+        load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, dims);
     }
 
     // Rhs must be loaded as plain. If it is transposed in global memory, we transpose it back.
     if Comptime::get(rhs_transposed) {
-        // load_rhs_transposed::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
+        load_rhs_transposed::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, dims);
     } else {
-        load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
+        load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, dims);
     }
 }
 
@@ -48,15 +48,15 @@ pub(crate) fn load_lhs_transposed<F: Float>(
     batch_offset: UInt,
     shared_lhs: SharedMemory<F>,
     config: Comptime<CubeTiling2dConfig>,
-    info: CubeTiling2dInfo,
+    dims: Dimensions,
 ) {
     let block_size_m = Comptime::map(config, |c| c.block_size_m);
     let tile_size = Comptime::map(config, |c| c.tile_size);
 
     let sm_stride = Comptime::runtime(block_size_m);
 
-    let cube_offset = coordinates.skip_row * info.lhs_stride;
-    let offset = cube_offset + k + batch_offset;
+    let tensor_stride = dims.k;
+    let offset = coordinates.skip_row * tensor_stride + k + batch_offset;
 
     let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
 
@@ -68,9 +68,9 @@ pub(crate) fn load_lhs_transposed<F: Float>(
         coordinates.unit_col,
         coordinates.skip_row,
         k,
-        info.lhs_stride,
-        info.dim_m,
-        info.dim_k,
+        tensor_stride,
+        dims.m,
+        dims.k,
         Comptime::map(config, |c| c.check_m_bounds),
         Comptime::map(config, |c| c.check_k_bounds),
         config,
@@ -86,6 +86,52 @@ pub(crate) fn load_lhs_transposed<F: Float>(
     );
 }
 
+#[cube]
+pub(crate) fn load_lhs_plain<F: Float>(
+    lhs: &Tensor<F>,
+    coordinates: Coordinates,
+    k: UInt,
+    batch_offset: UInt,
+    shared_lhs: SharedMemory<F>,
+    config: Comptime<CubeTiling2dConfig>,
+    dims: Dimensions,
+) {
+    let block_size_m = Comptime::map(config, |c| c.block_size_m);
+    let tile_size = Comptime::map(config, |c| c.tile_size);
+
+    let sm_stride = Comptime::runtime(block_size_m);
+
+    let tensor_stride = dims.m;
+    let offset = coordinates.skip_row + k * tensor_stride + batch_offset;
+
+    let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
+
+    read_tile_from_global_memory::<F>(
+        lhs,
+        &mut tile,
+        offset,
+        coordinates.unit_row,
+        coordinates.unit_col,
+        k,
+        coordinates.skip_row,
+        tensor_stride,
+        dims.k,
+        dims.m,
+        Comptime::map(config, |c| c.check_k_bounds),
+        Comptime::map(config, |c| c.check_m_bounds),
+        config,
+    );
+
+    write_tile_plain::<F>(
+        &tile,
+        shared_lhs,
+        coordinates.unit_row,
+        coordinates.unit_col,
+        sm_stride,
+        config,
+    );
+}
+
 #[cube]
 pub(crate) fn load_rhs_plain<F: Float>(
     rhs: &Tensor<F>,
@@ -94,14 +140,15 @@ pub(crate) fn load_rhs_plain<F: Float>(
     batch_offset: UInt,
     shared_rhs: SharedMemory<F>,
     config: Comptime<CubeTiling2dConfig>,
-    info: CubeTiling2dInfo,
+    dims: Dimensions,
 ) {
     let block_size_n = Comptime::map(config, |c| c.block_size_n);
     let tile_size = Comptime::map(config, |c| c.tile_size);
 
     let sm_stride = Comptime::runtime(block_size_n);
 
-    let offset = coordinates.skip_col + k * info.rhs_stride + batch_offset;
+    let tensor_stride = dims.n;
+    let offset = coordinates.skip_col + k * tensor_stride + batch_offset;
 
     let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
 
@@ -113,9 +160,9 @@ pub(crate) fn load_rhs_plain<F: Float>(
         coordinates.unit_col,
         k,
         coordinates.skip_col,
-        info.rhs_stride,
-        info.dim_k,
-        info.dim_n,
+        tensor_stride,
+        dims.k,
+        dims.n,
         Comptime::map(config, |c| c.check_k_bounds),
         Comptime::map(config, |c| c.check_n_bounds),
         config,
@@ -131,6 +178,52 @@ pub(crate) fn load_rhs_plain<F: Float>(
     );
 }
 
+#[cube]
+pub(crate) fn load_rhs_transposed<F: Float>(
+    rhs: &Tensor<F>,
+    coordinates: Coordinates,
+    k: UInt,
+    batch_offset: UInt,
+    shared_rhs: SharedMemory<F>,
+    config: Comptime<CubeTiling2dConfig>,
+    dims: Dimensions,
+) {
+    let block_size_n = Comptime::map(config, |c| c.block_size_n);
+    let tile_size = Comptime::map(config, |c| c.tile_size);
+
+    let sm_stride = Comptime::runtime(block_size_n);
+
+    let tensor_stride = dims.k;
+    let offset = coordinates.skip_col * tensor_stride + k + batch_offset;
+
+    let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
+
+    read_tile_from_global_memory::<F>(
+        rhs,
+        &mut tile,
+        offset,
+        coordinates.unit_row,
+        coordinates.unit_col,
+        coordinates.skip_col,
+        k,
+        tensor_stride,
+        dims.n,
+        dims.k,
+        Comptime::map(config, |c| c.check_n_bounds),
+        Comptime::map(config, |c| c.check_k_bounds),
+        config,
+    );
+
+    write_tile_transposed::<F>(
+        &tile,
+        shared_rhs,
+        coordinates.unit_col,
+        coordinates.unit_row,
+        sm_stride,
+        config,
+    );
+}
+
 #[cfg(feature = "export_tests")]
 /// Exported tests for loading to shared memory
 pub mod tests {
@@ -139,7 +232,7 @@ pub mod tests {
     };
     use crate::JitRuntime;
 
-    use super::{super::base::CoordinatesExpand, super::base::CubeTiling2dInfoExpand, *};
+    use super::{super::base::CoordinatesExpand, super::base::DimensionsExpand, *};
 
     #[cube(launch)]
     fn load_tensor_test<F: Float>(
         tensor: &Tensor<F>,
         sm_out: &mut Array<F>,
@@ -168,24 +261,18 @@ pub mod tests {
         };
 
         if Comptime::get(is_lhs) {
-            let info = CubeTiling2dInfo {
-                dim_m: tensor.shape(tensor.rank() - UInt::new(2)),
-                dim_k: tensor.shape(tensor.rank() - UInt::new(1)),
-                dim_n: UInt::new(0),
-                lhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
-                rhs_stride: UInt::new(0),
-                out_stride: UInt::new(0),
+            let info = Dimensions {
+                m: tensor.shape(tensor.rank() - UInt::new(2)),
+                k: tensor.shape(tensor.rank() - UInt::new(1)),
+                n: UInt::new(0),
             };
 
             load_lhs_transposed(tensor, coordinates, k, offset, shared_memory, config, info);
         } else {
-            let info = CubeTiling2dInfo {
-                dim_m: UInt::new(0),
-                dim_k: tensor.shape(tensor.rank() - UInt::new(2)),
-                dim_n: tensor.shape(tensor.rank() - UInt::new(1)),
-                lhs_stride: UInt::new(0),
-                rhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
-                out_stride: UInt::new(0),
+            let info = Dimensions {
+                m: UInt::new(0),
+                k: tensor.shape(tensor.rank() - UInt::new(2)),
+                n: tensor.shape(tensor.rank() - UInt::new(1)),
             };
 
             load_rhs_plain(tensor, coordinates, k, offset, shared_memory, config, info);
@@ -196,6 +283,57 @@ pub mod tests {
         }
     }
 
+    #[cube(launch)]
+    fn load_tensor_permuted_test<F: Float>(
+        tensor: &Tensor<F>,
+        sm_out: &mut Array<F>,
+        unit_row: UInt,
+        unit_col: UInt,
+        k: UInt,
+        config: Comptime<CubeTiling2dConfig>,
+        is_lhs: Comptime<bool>,
+    ) {
+        let tile_size = Comptime::map(config, |c| c.tile_size);
+        let block_size_k = Comptime::map(config, |c| c.block_size_k);
+        let block_size_m = Comptime::map(config, |c| c.block_size_m);
+        let sm_size = block_size_k * block_size_m / tile_size;
+        let shared_memory =
+            SharedMemory::<F>::vectorized(Comptime::get(sm_size), Comptime::get(tile_size));
+
+        let offset = UInt::new(0);
+
+        let coordinates = Coordinates {
+            unit_row,
+            unit_col,
+            skip_row: UInt::new(0),
+            skip_col: UInt::new(0),
+        };
+
+        if Comptime::get(is_lhs) {
+            // Permuted
+            let dims = Dimensions {
+                m: tensor.shape(tensor.rank() - UInt::new(1)),
+                k: tensor.shape(tensor.rank() - UInt::new(2)),
+                n: UInt::new(0),
+            };
+
+            load_lhs_plain(tensor, coordinates, k, offset, shared_memory, config, dims);
+        } else {
+            // Permuted
+            let dims = Dimensions {
+                m: UInt::new(0),
+                k: tensor.shape(tensor.rank() - UInt::new(1)),
+                n: tensor.shape(tensor.rank() - UInt::new(2)),
+            };
+
+            load_rhs_transposed(tensor, coordinates, k, offset, shared_memory, config, dims);
+        }
+
+        for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) {
+            sm_out[i] = shared_memory[i];
+        }
+    }
+
     #[cube(launch)]
     fn load_tensor_multiple_tiles_test<F: Float>(
         tensor: &Tensor<F>,
@@ -223,24 +361,18 @@ pub mod tests {
         };
 
         if Comptime::get(is_lhs) {
-            let info = CubeTiling2dInfo {
-                dim_m: tensor.shape(tensor.rank() - UInt::new(2)),
-                dim_k: tensor.shape(tensor.rank() - UInt::new(1)),
-                dim_n: UInt::new(0),
-                lhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
-                rhs_stride: UInt::new(0),
-                out_stride: UInt::new(0),
+            let info = Dimensions {
+                m: tensor.shape(tensor.rank() - UInt::new(2)),
+                k: tensor.shape(tensor.rank() - UInt::new(1)),
+                n: UInt::new(0),
             };
 
             load_lhs_transposed(tensor, coordinates, k, offset, shared_memory, config, info);
         } else {
-            let info = CubeTiling2dInfo {
-                dim_m: UInt::new(0),
-                dim_k: tensor.shape(tensor.rank() - UInt::new(2)),
-                dim_n: tensor.shape(tensor.rank() - UInt::new(1)),
-                lhs_stride: UInt::new(0),
-                rhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
-                out_stride: UInt::new(0),
+            let info = Dimensions {
+                m: UInt::new(0),
+                k: tensor.shape(tensor.rank() - UInt::new(2)),
+                n: tensor.shape(tensor.rank() - UInt::new(1)),
             };
 
             load_rhs_plain(tensor, coordinates, k, offset, shared_memory, config, info);
@@ -468,4 +600,131 @@ pub mod tests {
         ];
         assert_equals::<R>(sm_out, expected, device);
     }
+
+    /// Exported test
+    pub fn load_lhs_plain_unit_test<R: JitRuntime>(device: &R::Device) {
+        let lhs = range_tensor::<R>(16, 16, device);
+        let sm_out = create_empty::<R>(8, 8, device);
+        let cube_dim = CubeDim::new(1, 1, 1);
+        let cube_count = CubeCount::new(1, 1, 1);
+
+        let config = make_config(16, 16, 8);
+
+        load_tensor_permuted_test_launch::<F32, R>(
+            lhs.client.clone(),
+            cube_count,
+            cube_dim,
+            TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape.dims),
+            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
+            4,
+            4,
+            8,
+            config,
+            true,
+        );
+
+        let expected = &[
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 196.0, 197.0, 198.0, 199.0, 0.0, 0.0, 0.0, 0.0, 212.0, 213.0, 214.0, 215.0,
+            0.0, 0.0, 0.0, 0.0, 228.0, 229.0, 230.0, 231.0, 0.0, 0.0, 0.0, 0.0, 244.0, 245.0,
+            246.0, 247.0,
+        ];
+        assert_equals::<R>(sm_out, expected, device);
+    }
+
+    /// Exported test
+    pub fn load_lhs_plain_out_of_bounds_unit_test<R: JitRuntime>(device: &R::Device) {
+        let (m, k) = (6, 14);
+        let lhs = range_tensor::<R>(k, m, device);
+        let sm_out = create_empty::<R>(8, 8, device);
+        let cube_dim = CubeDim::new(1, 1, 1);
+        let cube_count = CubeCount::new(1, 1, 1);
+
+        let config = make_config(m, k, 8);
+
+        load_tensor_permuted_test_launch::<F32, R>(
+            lhs.client.clone(),
+            cube_count,
+            cube_dim,
+            TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
+            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
+            4,
+            4,
+            8,
+            config,
+            true,
+        );
+
+        let expected = &[
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 76.0, 77.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 82.0, 83.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+        ];
+        assert_equals::<R>(sm_out, expected, device);
+    }
+
+    /// Exported test
+    pub fn load_rhs_transposed_unit_test<R: JitRuntime>(device: &R::Device) {
+        let rhs = range_tensor::<R>(16, 16, device);
+        let sm_out = create_empty::<R>(8, 8, device);
+        let cube_dim = CubeDim::new(1, 1, 1);
+        let cube_count = CubeCount::new(1, 1, 1);
+
+        let config = make_config(16, 16, 8);
+
+        load_tensor_permuted_test_launch::<F32, R>(
+            rhs.client.clone(),
+            cube_count,
+            cube_dim,
+            TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims),
+            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
+            4,
+            4,
+            8,
+            config,
+            false,
+        );
+
+        let expected = &[
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 76.0, 92.0, 108.0, 124.0, 0.0, 0.0, 0.0, 0.0, 77.0, 93.0, 109.0, 125.0, 0.0,
+            0.0, 0.0, 0.0, 78.0, 94.0, 110.0, 126.0, 0.0, 0.0, 0.0, 0.0, 79.0, 95.0, 111.0, 127.0,
+        ];
+        assert_equals::<R>(sm_out, expected, device);
+    }
+
+    /// Exported test
+    pub fn load_rhs_transposed_out_of_bounds_unit_test<R: JitRuntime>(device: &R::Device) {
+        let (k, n) = (14, 6);
+        let rhs = range_tensor::<R>(n, k, device);
+        let sm_out = create_empty::<R>(8, 8, device);
+        let cube_dim = CubeDim::new(1, 1, 1);
+        let cube_count = CubeCount::new(1, 1, 1);
+
+        let config = make_config(8, k, n);
+
+        load_tensor_permuted_test_launch::<F32, R>(
+            rhs.client.clone(),
+            cube_count,
+            cube_dim,
+            TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
+            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
+            4,
+            4,
+            8,
+            config,
+            false,
+        );
+
+        let expected = &[
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 68.0, 82.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 69.0, 83.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+        ];
+        assert_equals::<R>(sm_out, expected, device);
+    }
 }
diff --git a/crates/burn-jit/src/tests/matmul.rs b/crates/burn-jit/src/tests/matmul.rs
index c559fd4af..c430476d6 100644
--- a/crates/burn-jit/src/tests/matmul.rs
+++ b/crates/burn-jit/src/tests/matmul.rs
@@ -518,6 +518,54 @@ mod tests {
         );
     }
 
+    #[test]
+    fn swapped_lhs_row_col_large_uneven_m() {
+        let (m, k, n) = (252, 256, 256);
+        let swap_lhs = [2, 3];
+        let swap_rhs = [0, 0];
+        let shape_lhs = [3, 2, k, m];
+        let shape_rhs = [3, 2, k, n];
+        same_as_reference_swapped_dims(
+            MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()),
+            swap_lhs,
+            swap_rhs,
+            shape_lhs,
+            shape_rhs,
+        );
+    }
+
+    #[test]
+    fn swapped_rhs_row_col_large_uneven_n() {
+        let (m, k, n) = (256, 256, 252);
+        let swap_lhs = [0, 0];
+        let swap_rhs = [2, 3];
+        let shape_lhs = [3, 2, m, k];
+        let shape_rhs = [3, 2, n, k];
+        same_as_reference_swapped_dims(
+            MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()),
+            swap_lhs,
+            swap_rhs,
+            shape_lhs,
+            shape_rhs,
+        );
+    }
+
+    #[test]
+    fn swapped_both_row_col_large_uneven_k() {
+        let (m, k, n) = (256, 252, 256);
+        let swap_lhs = [2, 3];
+        let swap_rhs = [2, 3];
+        let shape_lhs = [3, 2, k, m];
+        let shape_rhs = [3, 2, n, k];
+        same_as_reference_swapped_dims(
+            MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()),
+            swap_lhs,
+            swap_rhs,
+            shape_lhs,
+            shape_rhs,
+        );
+    }
+
     #[test]
     fn swapped_row_with_batch_no_padding() {
         let swap_lhs = [0, 3];
diff --git a/crates/burn-jit/src/tests/matmul_cube.rs b/crates/burn-jit/src/tests/matmul_cube.rs
index 74ef8e714..e61008e42 100644
--- a/crates/burn-jit/src/tests/matmul_cube.rs
+++ b/crates/burn-jit/src/tests/matmul_cube.rs
@@ -99,6 +99,18 @@ mod tests {
         load_shared_memory_tests::load_lhs_transposed_cube_test::<TestRuntime>(&Default::default())
     }
 
+    #[test]
+    pub fn load_lhs_plain_unit_test() {
+        load_shared_memory_tests::load_lhs_plain_unit_test::<TestRuntime>(&Default::default())
+    }
+
+    #[test]
+    pub fn load_lhs_plain_out_of_bounds_unit_test() {
+        load_shared_memory_tests::load_lhs_plain_out_of_bounds_unit_test::<TestRuntime>(
+            &Default::default(),
+        )
+    }
+
     #[test]
     pub fn load_lhs_transposed_out_of_bounds_cube_test() {
         load_shared_memory_tests::load_lhs_transposed_out_of_bounds_cube_test::<TestRuntime>(
             &Default::default(),
         )
     }
@@ -128,6 +140,18 @@ mod tests {
         load_shared_memory_tests::load_rhs_plain_cube_offset_test::<TestRuntime>(&Default::default())
     }
 
+    #[test]
+    pub fn load_rhs_transposed_unit_test() {
+        load_shared_memory_tests::load_rhs_transposed_unit_test::<TestRuntime>(&Default::default())
+    }
+
+    #[test]
+    pub fn load_rhs_transposed_out_of_bounds_unit_test() {
+        load_shared_memory_tests::load_rhs_transposed_out_of_bounds_unit_test::<TestRuntime>(
+            &Default::default(),
+        )
+    }
+
     #[test]
     pub fn write_results_inner_loop_unit_test() {
         write_output_tests::write_results_inner_loop_unit_test::<TestRuntime>(&Default::default())