offset check bounds work

This commit is contained in:
louisfd 2024-07-04 13:42:26 -04:00
parent cbb4de7156
commit 43877da1f2
5 changed files with 393 additions and 82 deletions

View File

@ -23,7 +23,7 @@ fn tiling2d_cube<F: Float>(
out: &mut Tensor<F>,
config: Comptime<CubeTiling2dConfig>,
) {
let info = get_info::<F>(lhs, rhs, out);
let dims = get_dims::<F>(lhs, rhs);
let coordinates = calculate_coordinates(CUBE_POS_X, CUBE_POS_Y, UNIT_POS, config);
let offsets = calculate_batch_offsets::<F>(lhs, rhs, out, CUBE_POS_Z);
let shared_memories = make_shared_memories::<F>(config);
@ -35,20 +35,17 @@ fn tiling2d_cube<F: Float>(
offsets,
shared_memories,
config,
info,
dims,
);
}
#[derive(CubeType, Copy, Clone)]
/// Information available at runtime only
/// Strides assume contiguous
pub(crate) struct CubeTiling2dInfo {
pub dim_m: UInt,
pub dim_k: UInt,
pub dim_n: UInt,
pub lhs_stride: UInt,
pub rhs_stride: UInt,
pub out_stride: UInt,
pub(crate) struct Dimensions {
pub m: UInt,
pub k: UInt,
pub n: UInt,
}
#[derive(CubeType, Copy, Clone)]
@ -75,25 +72,15 @@ pub(crate) struct Coordinates {
}
#[cube]
fn get_info<F: Float>(lhs: &Tensor<F>, rhs: &Tensor<F>, out: &Tensor<F>) -> CubeTiling2dInfo {
fn get_dims<F: Float>(lhs: &Tensor<F>, rhs: &Tensor<F>) -> Dimensions {
let rank = lhs.rank();
let first_dim = rank - UInt::new(2);
let second_dim = rank - UInt::new(1);
let dim_m = lhs.shape(first_dim);
let dim_k = lhs.shape(second_dim);
let dim_n = rhs.shape(second_dim);
let lhs_stride = lhs.stride(first_dim);
let rhs_stride = rhs.stride(first_dim);
let out_stride = out.stride(first_dim);
let m = lhs.shape(first_dim);
let k = lhs.shape(second_dim);
let n = rhs.shape(second_dim);
CubeTiling2dInfo {
dim_m,
dim_k,
dim_n,
lhs_stride,
rhs_stride,
out_stride,
}
Dimensions { m, k, n }
}
#[cube]

View File

@ -3,7 +3,7 @@ use burn_cube::prelude::*;
use crate::kernel::matmul::config::CubeTiling2dConfig;
use super::{
base::{BatchOffsets, Coordinates, CubeTiling2dInfo, SharedMemories},
base::{BatchOffsets, Coordinates, Dimensions, SharedMemories},
compute_loop::{compute_loop, compute_loop_expand},
load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand},
write_output::{write_to_output, write_to_output_expand},
@ -18,7 +18,7 @@ pub(crate) fn block_loop<F: Float>(
offsets: BatchOffsets,
shared: SharedMemories<F>,
config: Comptime<CubeTiling2dConfig>,
info: CubeTiling2dInfo,
dims: Dimensions,
) {
let block_size_k = Comptime::map(config, |c| c.block_size_k);
let mut results = init_results::<F>(config);
@ -28,7 +28,7 @@ pub(crate) fn block_loop<F: Float>(
for k in range(0u32, n_loops, Comptime::new(false)) {
let k = k * Comptime::runtime(block_size_k);
load_to_shared_memories(lhs, rhs, coordinates, k, offsets, shared, config, info);
load_to_shared_memories(lhs, rhs, coordinates, k, offsets, shared, config, dims);
sync_units();
@ -37,14 +37,7 @@ pub(crate) fn block_loop<F: Float>(
sync_units();
}
write_to_output::<F>(
out,
&results,
coordinates,
offsets.out,
info.out_stride,
config,
);
write_to_output::<F>(out, &results, coordinates, offsets.out, dims.n, config);
}
#[cube]

View File

@ -3,7 +3,7 @@ use burn_cube::prelude::*;
use crate::kernel::matmul::config::CubeTiling2dConfig;
use super::{
base::{BatchOffsets, Coordinates, CubeTiling2dInfo, SharedMemories},
base::{BatchOffsets, Coordinates, Dimensions, SharedMemories},
tile_read::{read_tile_from_global_memory, read_tile_from_global_memory_expand},
tile_write::{
write_tile_plain, write_tile_plain_expand, write_tile_transposed,
@ -20,23 +20,23 @@ pub(crate) fn load_to_shared_memories<F: Float>(
offsets: BatchOffsets,
shared: SharedMemories<F>,
config: Comptime<CubeTiling2dConfig>,
info: CubeTiling2dInfo,
dims: Dimensions,
) {
let lhs_transposed = Comptime::map(config, |c| c.lhs_transposed);
let rhs_transposed = Comptime::map(config, |c| c.rhs_transposed);
// Lhs must be loaded as transposed. If it already is in global memory, we load as plain.
// Lhs must be loaded as transposed. If it already is transposed in global memory, we load as plain.
if Comptime::get(lhs_transposed) {
// load_lhs_plain::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
load_lhs_plain::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, dims);
} else {
load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, dims);
}
// Rhs must be loaded as plain. If it is transposed in global memory, we transpose it back.
if Comptime::get(rhs_transposed) {
// load_rhs_transposed::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
load_rhs_transposed::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, dims);
} else {
load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, dims);
}
}
@ -48,15 +48,15 @@ pub(crate) fn load_lhs_transposed<F: Float>(
batch_offset: UInt,
shared_lhs: SharedMemory<F>,
config: Comptime<CubeTiling2dConfig>,
info: CubeTiling2dInfo,
dims: Dimensions,
) {
let block_size_m = Comptime::map(config, |c| c.block_size_m);
let tile_size = Comptime::map(config, |c| c.tile_size);
let sm_stride = Comptime::runtime(block_size_m);
let cube_offset = coordinates.skip_row * info.lhs_stride;
let offset = cube_offset + k + batch_offset;
let tensor_stride = dims.k;
let offset = coordinates.skip_row * tensor_stride + k + batch_offset;
let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
@ -68,9 +68,9 @@ pub(crate) fn load_lhs_transposed<F: Float>(
coordinates.unit_col,
coordinates.skip_row,
k,
info.lhs_stride,
info.dim_m,
info.dim_k,
tensor_stride,
dims.m,
dims.k,
Comptime::map(config, |c| c.check_m_bounds),
Comptime::map(config, |c| c.check_k_bounds),
config,
@ -86,6 +86,52 @@ pub(crate) fn load_lhs_transposed<F: Float>(
);
}
#[cube]
/// Reads one tile of `lhs` from global memory and stores it, untransposed, in `shared_lhs`.
///
/// Global addressing uses `dims.m` as the tensor stride: the tile offset is
/// `coordinates.skip_row + k * dims.m + batch_offset`. The read is bounded by `dims.k`
/// then `dims.m`, guarded by the `check_k_bounds` and `check_m_bounds` config flags
/// respectively. The shared-memory stride is `block_size_m`.
///
/// NOTE(review): presumably used when lhs is already transposed in global memory
/// (see the comment in `load_to_shared_memories`) — confirm against the caller.
pub(crate) fn load_lhs_plain<F: Float>(
    lhs: &Tensor<F>,
    coordinates: Coordinates,
    k: UInt,
    batch_offset: UInt,
    shared_lhs: SharedMemory<F>,
    config: Comptime<CubeTiling2dConfig>,
    dims: Dimensions,
) {
    // Compile-time configuration values.
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let block_size_m = Comptime::map(config, |c| c.block_size_m);
    let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds);
    let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds);

    // Runtime strides and the starting position of this read in global memory.
    let global_stride = dims.m;
    let shared_stride = Comptime::runtime(block_size_m);
    let global_offset = coordinates.skip_row + k * global_stride + batch_offset;

    let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));

    // `k` is passed before `skip_row`, and `dims.k` before `dims.m`: the reverse of the
    // argument order used by `load_lhs_transposed`.
    read_tile_from_global_memory::<F>(
        lhs,
        &mut tile,
        global_offset,
        coordinates.unit_row,
        coordinates.unit_col,
        k,
        coordinates.skip_row,
        global_stride,
        dims.k,
        dims.m,
        check_k_bounds,
        check_m_bounds,
        config,
    );

    write_tile_plain::<F>(
        &tile,
        shared_lhs,
        coordinates.unit_row,
        coordinates.unit_col,
        shared_stride,
        config,
    );
}
#[cube]
pub(crate) fn load_rhs_plain<F: Float>(
rhs: &Tensor<F>,
@ -94,14 +140,15 @@ pub(crate) fn load_rhs_plain<F: Float>(
batch_offset: UInt,
shared_rhs: SharedMemory<F>,
config: Comptime<CubeTiling2dConfig>,
info: CubeTiling2dInfo,
dims: Dimensions,
) {
let block_size_n = Comptime::map(config, |c| c.block_size_n);
let tile_size = Comptime::map(config, |c| c.tile_size);
let sm_stride = Comptime::runtime(block_size_n);
let offset = coordinates.skip_col + k * info.rhs_stride + batch_offset;
let tensor_stride = dims.n;
let offset = coordinates.skip_col + k * tensor_stride + batch_offset;
let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
@ -113,9 +160,9 @@ pub(crate) fn load_rhs_plain<F: Float>(
coordinates.unit_col,
k,
coordinates.skip_col,
info.rhs_stride,
info.dim_k,
info.dim_n,
tensor_stride,
dims.k,
dims.n,
Comptime::map(config, |c| c.check_k_bounds),
Comptime::map(config, |c| c.check_n_bounds),
config,
@ -131,6 +178,52 @@ pub(crate) fn load_rhs_plain<F: Float>(
);
}
#[cube]
/// Reads one tile of `rhs` from global memory and stores it transposed in `shared_rhs`.
///
/// Global addressing uses `dims.k` as the tensor stride: the tile offset is
/// `coordinates.skip_col * dims.k + k + batch_offset`. The read is bounded by `dims.n`
/// then `dims.k`, guarded by the `check_n_bounds` and `check_k_bounds` config flags
/// respectively. The shared-memory stride is `block_size_n`.
///
/// NOTE(review): presumably used when rhs is transposed in global memory and must be
/// transposed back (see the comment in `load_to_shared_memories`) — confirm against the caller.
pub(crate) fn load_rhs_transposed<F: Float>(
    rhs: &Tensor<F>,
    coordinates: Coordinates,
    k: UInt,
    batch_offset: UInt,
    shared_rhs: SharedMemory<F>,
    config: Comptime<CubeTiling2dConfig>,
    dims: Dimensions,
) {
    // Compile-time configuration values.
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let block_size_n = Comptime::map(config, |c| c.block_size_n);
    let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds);
    let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds);

    // Runtime strides and the starting position of this read in global memory.
    let global_stride = dims.k;
    let shared_stride = Comptime::runtime(block_size_n);
    let global_offset = coordinates.skip_col * global_stride + k + batch_offset;

    let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));

    read_tile_from_global_memory::<F>(
        rhs,
        &mut tile,
        global_offset,
        coordinates.unit_row,
        coordinates.unit_col,
        coordinates.skip_col,
        k,
        global_stride,
        dims.n,
        dims.k,
        check_n_bounds,
        check_k_bounds,
        config,
    );

    // Unit column and row are deliberately passed in swapped order for the transposed write.
    write_tile_transposed::<F>(
        &tile,
        shared_rhs,
        coordinates.unit_col,
        coordinates.unit_row,
        shared_stride,
        config,
    );
}
#[cfg(feature = "export_tests")]
/// Exported tests for loading to shared memory
pub mod tests {
@ -139,7 +232,7 @@ pub mod tests {
};
use crate::JitRuntime;
use super::{super::base::CoordinatesExpand, super::base::CubeTiling2dInfoExpand, *};
use super::{super::base::CoordinatesExpand, super::base::DimensionsExpand, *};
#[cube(launch)]
fn load_tensor_test<F: Float>(
@ -168,24 +261,18 @@ pub mod tests {
};
if Comptime::get(is_lhs) {
let info = CubeTiling2dInfo {
dim_m: tensor.shape(tensor.rank() - UInt::new(2)),
dim_k: tensor.shape(tensor.rank() - UInt::new(1)),
dim_n: UInt::new(0),
lhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
rhs_stride: UInt::new(0),
out_stride: UInt::new(0),
let info = Dimensions {
m: tensor.shape(tensor.rank() - UInt::new(2)),
k: tensor.shape(tensor.rank() - UInt::new(1)),
n: UInt::new(0),
};
load_lhs_transposed(tensor, coordinates, k, offset, shared_memory, config, info);
} else {
let info = CubeTiling2dInfo {
dim_m: UInt::new(0),
dim_k: tensor.shape(tensor.rank() - UInt::new(2)),
dim_n: tensor.shape(tensor.rank() - UInt::new(1)),
lhs_stride: UInt::new(0),
rhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
out_stride: UInt::new(0),
let info = Dimensions {
m: UInt::new(0),
k: tensor.shape(tensor.rank() - UInt::new(2)),
n: tensor.shape(tensor.rank() - UInt::new(1)),
};
load_rhs_plain(tensor, coordinates, k, offset, shared_memory, config, info);
@ -196,6 +283,57 @@ pub mod tests {
}
}
#[cube(launch)]
/// Test kernel: loads `tensor` through `load_lhs_plain` (when `is_lhs`) or
/// `load_rhs_transposed` (otherwise), then copies the shared memory into `sm_out`.
///
/// The `Dimensions` are built with the tensor's last two axes permuted: for lhs,
/// `m` is the last axis and `k` the second-to-last; for rhs, `k` is the last axis and
/// `n` the second-to-last. The unused dimension is set to zero. Batch offset and
/// cube skips are all zero.
fn load_tensor_permuted_test<F: Float>(
    tensor: &Tensor<F>,
    sm_out: &mut Array<F>,
    unit_row: UInt,
    unit_col: UInt,
    k: UInt,
    config: Comptime<CubeTiling2dConfig>,
    is_lhs: Comptime<bool>,
) {
    let block_size_m = Comptime::map(config, |c| c.block_size_m);
    let block_size_k = Comptime::map(config, |c| c.block_size_k);
    let tile_size = Comptime::map(config, |c| c.tile_size);

    // Number of vectorized elements held by the shared memory.
    let sm_size = block_size_k * block_size_m / tile_size;
    let sm = SharedMemory::<F>::vectorized(Comptime::get(sm_size), Comptime::get(tile_size));

    let batch_offset = UInt::new(0);
    let coordinates = Coordinates {
        unit_row,
        unit_col,
        skip_row: UInt::new(0),
        skip_col: UInt::new(0),
    };

    let rank = tensor.rank();
    let last_axis_len = tensor.shape(rank - UInt::new(1));
    let second_last_axis_len = tensor.shape(rank - UInt::new(2));

    if Comptime::get(is_lhs) {
        // Permuted: m comes from the last axis, k from the second-to-last.
        let permuted = Dimensions {
            m: last_axis_len,
            k: second_last_axis_len,
            n: UInt::new(0),
        };

        load_lhs_plain(tensor, coordinates, k, batch_offset, sm, config, permuted);
    } else {
        // Permuted: k comes from the last axis, n from the second-to-last.
        let permuted = Dimensions {
            m: UInt::new(0),
            k: last_axis_len,
            n: second_last_axis_len,
        };

        load_rhs_transposed(tensor, coordinates, k, batch_offset, sm, config, permuted);
    }

    // Expose the shared memory content to the host.
    for idx in range(0u32, Comptime::get(sm_size), Comptime::new(false)) {
        sm_out[idx] = sm[idx];
    }
}
#[cube(launch)]
fn load_tensor_multiple_tiles_test<F: Float>(
tensor: &Tensor<F>,
@ -223,24 +361,18 @@ pub mod tests {
};
if Comptime::get(is_lhs) {
let info = CubeTiling2dInfo {
dim_m: tensor.shape(tensor.rank() - UInt::new(2)),
dim_k: tensor.shape(tensor.rank() - UInt::new(1)),
dim_n: UInt::new(0),
lhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
rhs_stride: UInt::new(0),
out_stride: UInt::new(0),
let info = Dimensions {
m: tensor.shape(tensor.rank() - UInt::new(2)),
k: tensor.shape(tensor.rank() - UInt::new(1)),
n: UInt::new(0),
};
load_lhs_transposed(tensor, coordinates, k, offset, shared_memory, config, info);
} else {
let info = CubeTiling2dInfo {
dim_m: UInt::new(0),
dim_k: tensor.shape(tensor.rank() - UInt::new(2)),
dim_n: tensor.shape(tensor.rank() - UInt::new(1)),
lhs_stride: UInt::new(0),
rhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
out_stride: UInt::new(0),
let info = Dimensions {
m: UInt::new(0),
k: tensor.shape(tensor.rank() - UInt::new(2)),
n: tensor.shape(tensor.rank() - UInt::new(1)),
};
load_rhs_plain(tensor, coordinates, k, offset, shared_memory, config, info);
@ -468,4 +600,131 @@ pub mod tests {
];
assert_equals::<R>(sm_out, expected, device);
}
/// Exported test
///
/// Launches `load_tensor_permuted_test` on the lhs path with a 16x16 range tensor,
/// a single unit at `(unit_row, unit_col) = (4, 4)` and `k = 8`, then compares the
/// 64-element shared-memory dump with the expected values.
pub fn load_lhs_plain_unit_test<R: JitRuntime>(device: &R::Device) {
    let tensor = range_tensor::<R>(16, 16, device);
    let shared_out = create_empty::<R>(8, 8, device);
    let config = make_config(16, 16, 8);

    let (unit_row, unit_col, k) = (4, 4, 8);
    let is_lhs = true;

    load_tensor_permuted_test_launch::<F32, R>(
        tensor.client.clone(),
        CubeCount::new(1, 1, 1),
        CubeDim::new(1, 1, 1),
        TensorArg::vectorized(
            TILE_SIZE as u8,
            &tensor.handle,
            &tensor.strides,
            &tensor.shape.dims,
        ),
        ArrayArg::vectorized(TILE_SIZE as u8, &shared_out, 64),
        unit_row,
        unit_col,
        k,
        config,
        is_lhs,
    );

    // Laid out as the 8x8 shared memory: only the tile at row 4, column 4 is written.
    #[rustfmt::skip]
    let expected = &[
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 196.0, 197.0, 198.0, 199.0,
        0.0, 0.0, 0.0, 0.0, 212.0, 213.0, 214.0, 215.0,
        0.0, 0.0, 0.0, 0.0, 228.0, 229.0, 230.0, 231.0,
        0.0, 0.0, 0.0, 0.0, 244.0, 245.0, 246.0, 247.0,
    ];
    assert_equals::<R>(shared_out, expected, device);
}
/// Exported test
///
/// Same as the plain lhs unit test but with `m = 6` and `k = 14`, so part of the
/// requested tile lies outside the tensor; those positions are expected to stay at zero.
/// The tensor is created with shape `(k, m)` and passed without vectorization.
pub fn load_lhs_plain_out_of_bounds_unit_test<R: JitRuntime>(device: &R::Device) {
    let (m, k) = (6, 14);
    let tensor = range_tensor::<R>(k, m, device);
    let shared_out = create_empty::<R>(8, 8, device);
    let config = make_config(m, k, 8);

    let (unit_row, unit_col, k_offset) = (4, 4, 8);
    let is_lhs = true;

    load_tensor_permuted_test_launch::<F32, R>(
        tensor.client.clone(),
        CubeCount::new(1, 1, 1),
        CubeDim::new(1, 1, 1),
        TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
        ArrayArg::vectorized(TILE_SIZE as u8, &shared_out, 64),
        unit_row,
        unit_col,
        k_offset,
        config,
        is_lhs,
    );

    // Laid out as the 8x8 shared memory: only the in-bounds 2x2 corner of the tile is filled.
    #[rustfmt::skip]
    let expected = &[
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 76.0, 77.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 82.0, 83.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ];
    assert_equals::<R>(shared_out, expected, device);
}
/// Exported test
///
/// Launches `load_tensor_permuted_test` on the rhs path with a 16x16 range tensor,
/// a single unit at `(unit_row, unit_col) = (4, 4)` and `k = 8`, then compares the
/// 64-element shared-memory dump with the expected (transposed) values.
pub fn load_rhs_transposed_unit_test<R: JitRuntime>(device: &R::Device) {
    let tensor = range_tensor::<R>(16, 16, device);
    let shared_out = create_empty::<R>(8, 8, device);
    let config = make_config(16, 16, 8);

    let (unit_row, unit_col, k) = (4, 4, 8);
    let is_lhs = false;

    load_tensor_permuted_test_launch::<F32, R>(
        tensor.client.clone(),
        CubeCount::new(1, 1, 1),
        CubeDim::new(1, 1, 1),
        TensorArg::vectorized(
            TILE_SIZE as u8,
            &tensor.handle,
            &tensor.strides,
            &tensor.shape.dims,
        ),
        ArrayArg::vectorized(TILE_SIZE as u8, &shared_out, 64),
        unit_row,
        unit_col,
        k,
        config,
        is_lhs,
    );

    // Laid out as the 8x8 shared memory: consecutive source values run down the columns.
    #[rustfmt::skip]
    let expected = &[
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 76.0, 92.0, 108.0, 124.0,
        0.0, 0.0, 0.0, 0.0, 77.0, 93.0, 109.0, 125.0,
        0.0, 0.0, 0.0, 0.0, 78.0, 94.0, 110.0, 126.0,
        0.0, 0.0, 0.0, 0.0, 79.0, 95.0, 111.0, 127.0,
    ];
    assert_equals::<R>(shared_out, expected, device);
}
/// Exported test
///
/// Same as the transposed rhs unit test but with `k = 14` and `n = 6`, so part of the
/// requested tile lies outside the tensor; those positions are expected to stay at zero.
/// The tensor is created with shape `(n, k)` and passed without vectorization.
pub fn load_rhs_transposed_out_of_bounds_unit_test<R: JitRuntime>(device: &R::Device) {
    let (k, n) = (14, 6);
    let tensor = range_tensor::<R>(n, k, device);
    let shared_out = create_empty::<R>(8, 8, device);
    let config = make_config(8, k, n);

    let (unit_row, unit_col, k_offset) = (4, 4, 8);
    let is_lhs = false;

    load_tensor_permuted_test_launch::<F32, R>(
        tensor.client.clone(),
        CubeCount::new(1, 1, 1),
        CubeDim::new(1, 1, 1),
        TensorArg::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
        ArrayArg::vectorized(TILE_SIZE as u8, &shared_out, 64),
        unit_row,
        unit_col,
        k_offset,
        config,
        is_lhs,
    );

    // Laid out as the 8x8 shared memory: only the in-bounds 2x2 corner of the tile is filled.
    #[rustfmt::skip]
    let expected = &[
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 68.0, 82.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 69.0, 83.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ];
    assert_equals::<R>(shared_out, expected, device);
}
}

View File

@ -518,6 +518,54 @@ mod tests {
);
}
#[test]
/// Compares the tiling2d cube matmul with the reference when lhs is created with shape
/// `[3, 2, k, m]` and its axes 2 and 3 are listed for swapping, with `m = 252`.
/// Rhs is created as `[3, 2, k, n]` with swap `[0, 0]` (presumably a no-op swap).
fn swapped_lhs_row_col_large_uneven_m() {
    let (m, k, n) = (252, 256, 256);
    let swap_lhs = [2, 3];
    let swap_rhs = [0, 0];
    let shape_lhs = [3, 2, k, m];
    let shape_rhs = [3, 2, k, n];
    same_as_reference_swapped_dims(
        // Single pair of parentheses: the doubled form triggers `unused_parens`.
        MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()),
        swap_lhs,
        swap_rhs,
        shape_lhs,
        shape_rhs,
    );
}
#[test]
/// Compares the tiling2d cube matmul with the reference when rhs is created with shape
/// `[3, 2, n, k]` and its axes 2 and 3 are listed for swapping, with `n = 252`.
/// Lhs is created as `[3, 2, m, k]` with swap `[0, 0]` (presumably a no-op swap).
fn swapped_rhs_row_col_large_uneven_n() {
    let (m, k, n) = (256, 256, 252);
    let swap_lhs = [0, 0];
    let swap_rhs = [2, 3];
    let shape_lhs = [3, 2, m, k];
    let shape_rhs = [3, 2, n, k];
    same_as_reference_swapped_dims(
        // Single pair of parentheses: the doubled form triggers `unused_parens`.
        MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()),
        swap_lhs,
        swap_rhs,
        shape_lhs,
        shape_rhs,
    );
}
#[test]
/// Compares the tiling2d cube matmul with the reference when both operands have their
/// axes 2 and 3 listed for swapping: lhs is created as `[3, 2, k, m]` and rhs as
/// `[3, 2, n, k]`, with `k = 252`.
fn swapped_both_row_col_large_uneven_k() {
    let (m, k, n) = (256, 252, 256);
    let swap_lhs = [2, 3];
    let swap_rhs = [2, 3];
    let shape_lhs = [3, 2, k, m];
    let shape_rhs = [3, 2, n, k];
    same_as_reference_swapped_dims(
        // Single pair of parentheses: the doubled form triggers `unused_parens`.
        MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()),
        swap_lhs,
        swap_rhs,
        shape_lhs,
        shape_rhs,
    );
}
#[test]
fn swapped_row_with_batch_no_padding() {
let swap_lhs = [0, 3];

View File

@ -99,6 +99,18 @@ mod tests {
load_shared_memory_tests::load_lhs_transposed_cube_test::<TestRuntime>(&Default::default())
}
#[test]
/// Runs the exported `load_lhs_plain_unit_test` on `TestRuntime` with a default device.
pub fn load_lhs_plain_unit_test() {
    let device = Default::default();
    load_shared_memory_tests::load_lhs_plain_unit_test::<TestRuntime>(&device)
}
#[test]
/// Runs the exported `load_lhs_plain_out_of_bounds_unit_test` on `TestRuntime` with a
/// default device.
pub fn load_lhs_plain_out_of_bounds_unit_test() {
    let device = Default::default();
    load_shared_memory_tests::load_lhs_plain_out_of_bounds_unit_test::<TestRuntime>(&device)
}
#[test]
pub fn load_lhs_transposed_out_of_bounds_cube_test() {
load_shared_memory_tests::load_lhs_transposed_out_of_bounds_cube_test::<TestRuntime>(
@ -128,6 +140,18 @@ mod tests {
load_shared_memory_tests::load_rhs_plain_cube_offset_test::<TestRuntime>(&Default::default())
}
#[test]
/// Runs the exported `load_rhs_transposed_unit_test` on `TestRuntime` with a default device.
pub fn load_rhs_transposed_unit_test() {
    let device = Default::default();
    load_shared_memory_tests::load_rhs_transposed_unit_test::<TestRuntime>(&device)
}
#[test]
/// Runs the exported `load_rhs_transposed_out_of_bounds_unit_test` on `TestRuntime` with a
/// default device.
pub fn load_rhs_transposed_out_of_bounds_unit_test() {
    let device = Default::default();
    load_shared_memory_tests::load_rhs_transposed_out_of_bounds_unit_test::<TestRuntime>(&device)
}
#[test]
pub fn write_results_inner_loop_unit_test() {
write_output_tests::write_results_inner_loop_unit_test::<TestRuntime>(&Default::default())