mirror of https://github.com/tracel-ai/burn.git
offset check bounds work
This commit is contained in:
parent cbb4de7156
commit 43877da1f2
@@ -23,7 +23,7 @@ fn tiling2d_cube<F: Float>(
     out: &mut Tensor<F>,
     config: Comptime<CubeTiling2dConfig>,
 ) {
-    let info = get_info::<F>(lhs, rhs, out);
+    let dims = get_dims::<F>(lhs, rhs);
     let coordinates = calculate_coordinates(CUBE_POS_X, CUBE_POS_Y, UNIT_POS, config);
     let offsets = calculate_batch_offsets::<F>(lhs, rhs, out, CUBE_POS_Z);
     let shared_memories = make_shared_memories::<F>(config);
@@ -35,20 +35,17 @@ fn tiling2d_cube<F: Float>(
         offsets,
         shared_memories,
         config,
-        info,
+        dims,
     );
 }
 
 #[derive(CubeType, Copy, Clone)]
 /// Information available at runtime only
 /// Strides assume contiguous
-pub(crate) struct CubeTiling2dInfo {
-    pub dim_m: UInt,
-    pub dim_k: UInt,
-    pub dim_n: UInt,
-    pub lhs_stride: UInt,
-    pub rhs_stride: UInt,
-    pub out_stride: UInt,
+pub(crate) struct Dimensions {
+    pub m: UInt,
+    pub k: UInt,
+    pub n: UInt,
 }
 
 #[derive(CubeType, Copy, Clone)]
@@ -75,25 +72,15 @@ pub(crate) struct Coordinates {
 }
 
 #[cube]
-fn get_info<F: Float>(lhs: &Tensor<F>, rhs: &Tensor<F>, out: &Tensor<F>) -> CubeTiling2dInfo {
+fn get_dims<F: Float>(lhs: &Tensor<F>, rhs: &Tensor<F>) -> Dimensions {
     let rank = lhs.rank();
     let first_dim = rank - UInt::new(2);
     let second_dim = rank - UInt::new(1);
-    let dim_m = lhs.shape(first_dim);
-    let dim_k = lhs.shape(second_dim);
-    let dim_n = rhs.shape(second_dim);
-    let lhs_stride = lhs.stride(first_dim);
-    let rhs_stride = rhs.stride(first_dim);
-    let out_stride = out.stride(first_dim);
+    let m = lhs.shape(first_dim);
+    let k = lhs.shape(second_dim);
+    let n = rhs.shape(second_dim);
 
-    CubeTiling2dInfo {
-        dim_m,
-        dim_k,
-        dim_n,
-        lhs_stride,
-        rhs_stride,
-        out_stride,
-    }
+    Dimensions { m, k, n }
 }
 
 #[cube]
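
Note: with get_dims above, the kernel no longer reads any stride from the tensors. As the doc comment on Dimensions says, strides assume contiguous, so every stride the removed CubeTiling2dInfo carried follows from the shape alone. A minimal sketch of that equivalence (plain Rust with usize standing in for the kernel's UInt; illustrative, not kernel code):

```rust
// Illustrative only: for contiguous row-major matrices, the row stride is
// the column count, so the strides the removed CubeTiling2dInfo carried
// are recoverable from the dimensions.
struct Dims {
    m: usize,
    k: usize,
    n: usize,
}

fn lhs_stride(d: &Dims) -> usize { d.k } // lhs is (m x k)
fn rhs_stride(d: &Dims) -> usize { d.n } // rhs is (k x n)
fn out_stride(d: &Dims) -> usize { d.n } // out is (m x n)

fn main() {
    let d = Dims { m: 4, k: 8, n: 2 };
    assert_eq!(d.m, 4);
    assert_eq!((lhs_stride(&d), rhs_stride(&d), out_stride(&d)), (8, 2, 2));
}
```

This is why the loaders below can pass dims.k, dims.m, or dims.n wherever a stride used to go.
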
@@ -3,7 +3,7 @@ use burn_cube::prelude::*;
 use crate::kernel::matmul::config::CubeTiling2dConfig;
 
 use super::{
-    base::{BatchOffsets, Coordinates, CubeTiling2dInfo, SharedMemories},
+    base::{BatchOffsets, Coordinates, Dimensions, SharedMemories},
     compute_loop::{compute_loop, compute_loop_expand},
     load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand},
     write_output::{write_to_output, write_to_output_expand},
@@ -18,7 +18,7 @@ pub(crate) fn block_loop<F: Float>(
     offsets: BatchOffsets,
     shared: SharedMemories<F>,
     config: Comptime<CubeTiling2dConfig>,
-    info: CubeTiling2dInfo,
+    dims: Dimensions,
 ) {
     let block_size_k = Comptime::map(config, |c| c.block_size_k);
     let mut results = init_results::<F>(config);
@@ -28,7 +28,7 @@ pub(crate) fn block_loop<F: Float>(
     for k in range(0u32, n_loops, Comptime::new(false)) {
         let k = k * Comptime::runtime(block_size_k);
 
-        load_to_shared_memories(lhs, rhs, coordinates, k, offsets, shared, config, info);
+        load_to_shared_memories(lhs, rhs, coordinates, k, offsets, shared, config, dims);
 
         sync_units();
 
@@ -37,14 +37,7 @@ pub(crate) fn block_loop<F: Float>(
         sync_units();
     }
 
-    write_to_output::<F>(
-        out,
-        &results,
-        coordinates,
-        offsets.out,
-        info.out_stride,
-        config,
-    );
+    write_to_output::<F>(out, &results, coordinates, offsets.out, dims.n, config);
 }
 
 #[cube]
@@ -3,7 +3,7 @@ use burn_cube::prelude::*;
 use crate::kernel::matmul::config::CubeTiling2dConfig;
 
 use super::{
-    base::{BatchOffsets, Coordinates, CubeTiling2dInfo, SharedMemories},
+    base::{BatchOffsets, Coordinates, Dimensions, SharedMemories},
     tile_read::{read_tile_from_global_memory, read_tile_from_global_memory_expand},
     tile_write::{
         write_tile_plain, write_tile_plain_expand, write_tile_transposed,
@@ -20,23 +20,23 @@ pub(crate) fn load_to_shared_memories<F: Float>(
     offsets: BatchOffsets,
     shared: SharedMemories<F>,
     config: Comptime<CubeTiling2dConfig>,
-    info: CubeTiling2dInfo,
+    dims: Dimensions,
 ) {
     let lhs_transposed = Comptime::map(config, |c| c.lhs_transposed);
     let rhs_transposed = Comptime::map(config, |c| c.rhs_transposed);
 
-    // Lhs must be loaded as transposed. If it already is in global memory, we load as plain.
+    // Lhs must be loaded as transposed. If it already is transposed in global memory, we load as plain.
     if Comptime::get(lhs_transposed) {
-        // load_lhs_plain::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
+        load_lhs_plain::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, dims);
     } else {
-        load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
+        load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, dims);
     }
 
     // Rhs must be loaded as plain. If it is transposed in global memory, we transpose it back.
     if Comptime::get(rhs_transposed) {
-        // load_rhs_transposed::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
+        load_rhs_transposed::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, dims);
     } else {
-        load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
+        load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, dims);
     }
 }
 
@@ -48,15 +48,15 @@ pub(crate) fn load_lhs_transposed<F: Float>(
     batch_offset: UInt,
     shared_lhs: SharedMemory<F>,
     config: Comptime<CubeTiling2dConfig>,
-    info: CubeTiling2dInfo,
+    dims: Dimensions,
 ) {
     let block_size_m = Comptime::map(config, |c| c.block_size_m);
     let tile_size = Comptime::map(config, |c| c.tile_size);
 
     let sm_stride = Comptime::runtime(block_size_m);
 
-    let cube_offset = coordinates.skip_row * info.lhs_stride;
-    let offset = cube_offset + k + batch_offset;
+    let tensor_stride = dims.k;
+    let offset = coordinates.skip_row * tensor_stride + k + batch_offset;
 
     let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
 
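
Note: for a row-major (m x k) lhs, the element at (row, col) lives at row * k + col, so the block's starting offset is skip_row * dims.k + k + batch_offset — the same address the removed info.lhs_stride computation produced, now derived from the shape. A hedged numeric check (made-up values, illustrative only):

```rust
// Hedged numeric check with made-up values (skip_row = 8, dims.k = 16,
// k = 4, batch_offset = 0): the removed stride-based form and the new
// shape-based form compute the same element offset for a contiguous lhs.
fn main() {
    let (dims_k, skip_row, k, batch_offset) = (16usize, 8usize, 4usize, 0usize);

    // Removed form: cube_offset = skip_row * info.lhs_stride; offset = cube_offset + k + batch.
    let lhs_stride = dims_k; // contiguity: lhs_stride == dims.k
    let old_offset = skip_row * lhs_stride + k + batch_offset;

    // New form: tensor_stride = dims.k; offset = skip_row * tensor_stride + k + batch.
    let new_offset = skip_row * dims_k + k + batch_offset;

    assert_eq!(old_offset, new_offset); // 132 either way
}
```
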
@@ -68,9 +68,9 @@ pub(crate) fn load_lhs_transposed<F: Float>(
         coordinates.unit_col,
         coordinates.skip_row,
         k,
-        info.lhs_stride,
-        info.dim_m,
-        info.dim_k,
+        tensor_stride,
+        dims.m,
+        dims.k,
         Comptime::map(config, |c| c.check_m_bounds),
         Comptime::map(config, |c| c.check_k_bounds),
         config,
@@ -86,6 +86,52 @@ pub(crate) fn load_lhs_transposed<F: Float>(
     );
 }
 
+#[cube]
+pub(crate) fn load_lhs_plain<F: Float>(
+    lhs: &Tensor<F>,
+    coordinates: Coordinates,
+    k: UInt,
+    batch_offset: UInt,
+    shared_lhs: SharedMemory<F>,
+    config: Comptime<CubeTiling2dConfig>,
+    dims: Dimensions,
+) {
+    let block_size_m = Comptime::map(config, |c| c.block_size_m);
+    let tile_size = Comptime::map(config, |c| c.tile_size);
+
+    let sm_stride = Comptime::runtime(block_size_m);
+
+    let tensor_stride = dims.m;
+    let offset = coordinates.skip_row + k * tensor_stride + batch_offset;
+
+    let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
+
+    read_tile_from_global_memory::<F>(
+        lhs,
+        &mut tile,
+        offset,
+        coordinates.unit_row,
+        coordinates.unit_col,
+        k,
+        coordinates.skip_row,
+        tensor_stride,
+        dims.k,
+        dims.m,
+        Comptime::map(config, |c| c.check_k_bounds),
+        Comptime::map(config, |c| c.check_m_bounds),
+        config,
+    );
+
+    write_tile_plain::<F>(
+        &tile,
+        shared_lhs,
+        coordinates.unit_row,
+        coordinates.unit_col,
+        sm_stride,
+        config,
+    );
+}
+
 #[cube]
 pub(crate) fn load_rhs_plain<F: Float>(
     rhs: &Tensor<F>,
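
Note: a "plain" lhs load means the lhs is physically stored transposed, as a (k x m) matrix; its row stride is therefore dims.m and the block starts at k * dims.m + skip_row. A short illustrative check, mirroring the parameters of load_lhs_plain_unit_test added further down:

```rust
// Illustrative only: a "plain" lhs is physically (k x m), so element
// (row_k, col_m) sits at row_k * m + col_m, and the tile block starts
// at skip_row + k * dims.m.
fn plain_lhs_offset(dims_m: usize, skip_row: usize, k: usize, batch_offset: usize) -> usize {
    skip_row + k * dims_m + batch_offset
}

fn main() {
    // With dims.m = 16, skip_row = 0, k = 8 (as in the unit test below):
    assert_eq!(plain_lhs_offset(16, 0, 8, 0), 128); // start of physical row 8
}
```
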
@@ -94,14 +140,15 @@ pub(crate) fn load_rhs_plain<F: Float>(
     batch_offset: UInt,
     shared_rhs: SharedMemory<F>,
     config: Comptime<CubeTiling2dConfig>,
-    info: CubeTiling2dInfo,
+    dims: Dimensions,
 ) {
     let block_size_n = Comptime::map(config, |c| c.block_size_n);
     let tile_size = Comptime::map(config, |c| c.tile_size);
 
     let sm_stride = Comptime::runtime(block_size_n);
 
-    let offset = coordinates.skip_col + k * info.rhs_stride + batch_offset;
+    let tensor_stride = dims.n;
+    let offset = coordinates.skip_col + k * tensor_stride + batch_offset;
 
     let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
 
@@ -113,9 +160,9 @@ pub(crate) fn load_rhs_plain<F: Float>(
         coordinates.unit_col,
         k,
         coordinates.skip_col,
-        info.rhs_stride,
-        info.dim_k,
-        info.dim_n,
+        tensor_stride,
+        dims.k,
+        dims.n,
         Comptime::map(config, |c| c.check_k_bounds),
         Comptime::map(config, |c| c.check_n_bounds),
         config,
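
Note: across the four load paths, the call sites pass read_tile_from_global_memory the tensor's two logical extents and the matching bounds-check flags in the same row-then-column order, and the stride always equals the column extent. A hedged summary, inferred from the call sites only (the function's definition lives in tile_read, outside this diff):

```rust
// Hedged summary of the call-site pattern in this diff (illustrative; not
// the real signature): every loader passes (row_extent, col_extent) plus
// matching (check_row, check_col) comptime flags.
fn extents(path: &str) -> (&'static str, (&'static str, &'static str)) {
    match path {
        "load_lhs_transposed" => ("dims.k", ("dims.m", "dims.k")),
        "load_lhs_plain" => ("dims.m", ("dims.k", "dims.m")),
        "load_rhs_plain" => ("dims.n", ("dims.k", "dims.n")),
        "load_rhs_transposed" => ("dims.k", ("dims.n", "dims.k")),
        _ => unreachable!(),
    }
}

fn main() {
    for path in [
        "load_lhs_transposed",
        "load_lhs_plain",
        "load_rhs_plain",
        "load_rhs_transposed",
    ] {
        let (stride, (_, col_extent)) = extents(path);
        // The contiguity assumption again: stride == column extent everywhere.
        assert_eq!(stride, col_extent, "{path}");
    }
}
```
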
@@ -131,6 +178,52 @@ pub(crate) fn load_rhs_plain<F: Float>(
     );
 }
 
+#[cube]
+pub(crate) fn load_rhs_transposed<F: Float>(
+    rhs: &Tensor<F>,
+    coordinates: Coordinates,
+    k: UInt,
+    batch_offset: UInt,
+    shared_rhs: SharedMemory<F>,
+    config: Comptime<CubeTiling2dConfig>,
+    dims: Dimensions,
+) {
+    let block_size_n = Comptime::map(config, |c| c.block_size_n);
+    let tile_size = Comptime::map(config, |c| c.tile_size);
+
+    let sm_stride = Comptime::runtime(block_size_n);
+
+    let tensor_stride = dims.k;
+    let offset = coordinates.skip_col * tensor_stride + k + batch_offset;
+
+    let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
+
+    read_tile_from_global_memory::<F>(
+        rhs,
+        &mut tile,
+        offset,
+        coordinates.unit_row,
+        coordinates.unit_col,
+        coordinates.skip_col,
+        k,
+        tensor_stride,
+        dims.n,
+        dims.k,
+        Comptime::map(config, |c| c.check_n_bounds),
+        Comptime::map(config, |c| c.check_k_bounds),
+        config,
+    );
+
+    write_tile_transposed::<F>(
+        &tile,
+        shared_rhs,
+        coordinates.unit_col,
+        coordinates.unit_row,
+        sm_stride,
+        config,
+    );
+}
+
 #[cfg(feature = "export_tests")]
 /// Exported tests for loading to shared memory
 pub mod tests {
@@ -139,7 +232,7 @@ pub mod tests {
     };
     use crate::JitRuntime;
 
-    use super::{super::base::CoordinatesExpand, super::base::CubeTiling2dInfoExpand, *};
+    use super::{super::base::CoordinatesExpand, super::base::DimensionsExpand, *};
 
     #[cube(launch)]
     fn load_tensor_test<F: Float>(
@@ -168,24 +261,18 @@ pub mod tests {
         };
 
         if Comptime::get(is_lhs) {
-            let info = CubeTiling2dInfo {
-                dim_m: tensor.shape(tensor.rank() - UInt::new(2)),
-                dim_k: tensor.shape(tensor.rank() - UInt::new(1)),
-                dim_n: UInt::new(0),
-                lhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
-                rhs_stride: UInt::new(0),
-                out_stride: UInt::new(0),
+            let info = Dimensions {
+                m: tensor.shape(tensor.rank() - UInt::new(2)),
+                k: tensor.shape(tensor.rank() - UInt::new(1)),
+                n: UInt::new(0),
             };
 
             load_lhs_transposed(tensor, coordinates, k, offset, shared_memory, config, info);
         } else {
-            let info = CubeTiling2dInfo {
-                dim_m: UInt::new(0),
-                dim_k: tensor.shape(tensor.rank() - UInt::new(2)),
-                dim_n: tensor.shape(tensor.rank() - UInt::new(1)),
-                lhs_stride: UInt::new(0),
-                rhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
-                out_stride: UInt::new(0),
+            let info = Dimensions {
+                m: UInt::new(0),
+                k: tensor.shape(tensor.rank() - UInt::new(2)),
+                n: tensor.shape(tensor.rank() - UInt::new(1)),
             };
 
             load_rhs_plain(tensor, coordinates, k, offset, shared_memory, config, info);
@@ -196,6 +283,57 @@ pub mod tests {
         }
     }
 
+    #[cube(launch)]
+    fn load_tensor_permuted_test<F: Float>(
+        tensor: &Tensor<F>,
+        sm_out: &mut Array<F>,
+        unit_row: UInt,
+        unit_col: UInt,
+        k: UInt,
+        config: Comptime<CubeTiling2dConfig>,
+        is_lhs: Comptime<bool>,
+    ) {
+        let tile_size = Comptime::map(config, |c| c.tile_size);
+        let block_size_k = Comptime::map(config, |c| c.block_size_k);
+        let block_size_m = Comptime::map(config, |c| c.block_size_m);
+        let sm_size = block_size_k * block_size_m / tile_size;
+        let shared_memory =
+            SharedMemory::<F>::vectorized(Comptime::get(sm_size), Comptime::get(tile_size));
+
+        let offset = UInt::new(0);
+
+        let coordinates = Coordinates {
+            unit_row,
+            unit_col,
+            skip_row: UInt::new(0),
+            skip_col: UInt::new(0),
+        };
+
+        if Comptime::get(is_lhs) {
+            // Permuted
+            let dims = Dimensions {
+                m: tensor.shape(tensor.rank() - UInt::new(1)),
+                k: tensor.shape(tensor.rank() - UInt::new(2)),
+                n: UInt::new(0),
+            };
+
+            load_lhs_plain(tensor, coordinates, k, offset, shared_memory, config, dims);
+        } else {
+            // Permuted
+            let dims = Dimensions {
+                m: UInt::new(0),
+                k: tensor.shape(tensor.rank() - UInt::new(1)),
+                n: tensor.shape(tensor.rank() - UInt::new(2)),
+            };
+
+            load_rhs_transposed(tensor, coordinates, k, offset, shared_memory, config, dims);
+        }
+
+        for i in range(0u32, Comptime::get(sm_size), Comptime::new(false)) {
+            sm_out[i] = shared_memory[i];
+        }
+    }
+
     #[cube(launch)]
     fn load_tensor_multiple_tiles_test<F: Float>(
         tensor: &Tensor<F>,
@@ -223,24 +361,18 @@ pub mod tests {
         };
 
         if Comptime::get(is_lhs) {
-            let info = CubeTiling2dInfo {
-                dim_m: tensor.shape(tensor.rank() - UInt::new(2)),
-                dim_k: tensor.shape(tensor.rank() - UInt::new(1)),
-                dim_n: UInt::new(0),
-                lhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
-                rhs_stride: UInt::new(0),
-                out_stride: UInt::new(0),
+            let info = Dimensions {
+                m: tensor.shape(tensor.rank() - UInt::new(2)),
+                k: tensor.shape(tensor.rank() - UInt::new(1)),
+                n: UInt::new(0),
             };
 
             load_lhs_transposed(tensor, coordinates, k, offset, shared_memory, config, info);
         } else {
-            let info = CubeTiling2dInfo {
-                dim_m: UInt::new(0),
-                dim_k: tensor.shape(tensor.rank() - UInt::new(2)),
-                dim_n: tensor.shape(tensor.rank() - UInt::new(1)),
-                lhs_stride: UInt::new(0),
-                rhs_stride: tensor.stride(tensor.rank() - UInt::new(2)),
-                out_stride: UInt::new(0),
+            let info = Dimensions {
+                m: UInt::new(0),
+                k: tensor.shape(tensor.rank() - UInt::new(2)),
+                n: tensor.shape(tensor.rank() - UInt::new(1)),
             };
 
             load_rhs_plain(tensor, coordinates, k, offset, shared_memory, config, info);
@@ -468,4 +600,131 @@ pub mod tests {
         ];
         assert_equals::<R>(sm_out, expected, device);
     }
+
+    /// Exported test
+    pub fn load_lhs_plain_unit_test<R: JitRuntime>(device: &R::Device) {
+        let lhs = range_tensor::<R>(16, 16, device);
+        let sm_out = create_empty::<R>(8, 8, device);
+        let cube_dim = CubeDim::new(1, 1, 1);
+        let cube_count = CubeCount::new(1, 1, 1);
+
+        let config = make_config(16, 16, 8);
+
+        load_tensor_permuted_test_launch::<F32, R>(
+            lhs.client.clone(),
+            cube_count,
+            cube_dim,
+            TensorArg::vectorized(TILE_SIZE as u8, &lhs.handle, &lhs.strides, &lhs.shape.dims),
+            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
+            4,
+            4,
+            8,
+            config,
+            true,
+        );
+
+        let expected = &[
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 196.0, 197.0, 198.0, 199.0, 0.0, 0.0, 0.0, 0.0, 212.0, 213.0, 214.0, 215.0,
+            0.0, 0.0, 0.0, 0.0, 228.0, 229.0, 230.0, 231.0, 0.0, 0.0, 0.0, 0.0, 244.0, 245.0,
+            246.0, 247.0,
+        ];
+        assert_equals::<R>(sm_out, expected, device);
+    }
+
+    /// Exported test
+    pub fn load_lhs_plain_out_of_bounds_unit_test<R: JitRuntime>(device: &R::Device) {
+        let (m, k) = (6, 14);
+        let lhs = range_tensor::<R>(k, m, device);
+        let sm_out = create_empty::<R>(8, 8, device);
+        let cube_dim = CubeDim::new(1, 1, 1);
+        let cube_count = CubeCount::new(1, 1, 1);
+
+        let config = make_config(m, k, 8);
+
+        load_tensor_permuted_test_launch::<F32, R>(
+            lhs.client.clone(),
+            cube_count,
+            cube_dim,
+            TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
+            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
+            4,
+            4,
+            8,
+            config,
+            true,
+        );
+
+        let expected = &[
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 76.0, 77.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 82.0, 83.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+        ];
+        assert_equals::<R>(sm_out, expected, device);
+    }
+
+    /// Exported test
+    pub fn load_rhs_transposed_unit_test<R: JitRuntime>(device: &R::Device) {
+        let rhs = range_tensor::<R>(16, 16, device);
+        let sm_out = create_empty::<R>(8, 8, device);
+        let cube_dim = CubeDim::new(1, 1, 1);
+        let cube_count = CubeCount::new(1, 1, 1);
+
+        let config = make_config(16, 16, 8);
+
+        load_tensor_permuted_test_launch::<F32, R>(
+            rhs.client.clone(),
+            cube_count,
+            cube_dim,
+            TensorArg::vectorized(TILE_SIZE as u8, &rhs.handle, &rhs.strides, &rhs.shape.dims),
+            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
+            4,
+            4,
+            8,
+            config,
+            false,
+        );
+
+        let expected = &[
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 76.0, 92.0, 108.0, 124.0, 0.0, 0.0, 0.0, 0.0, 77.0, 93.0, 109.0, 125.0, 0.0,
+            0.0, 0.0, 0.0, 78.0, 94.0, 110.0, 126.0, 0.0, 0.0, 0.0, 0.0, 79.0, 95.0, 111.0, 127.0,
+        ];
+        assert_equals::<R>(sm_out, expected, device);
+    }
+
+    /// Exported test
+    pub fn load_rhs_transposed_out_of_bounds_unit_test<R: JitRuntime>(device: &R::Device) {
+        let (k, n) = (14, 6);
+        let rhs = range_tensor::<R>(n, k, device);
+        let sm_out = create_empty::<R>(8, 8, device);
+        let cube_dim = CubeDim::new(1, 1, 1);
+        let cube_count = CubeCount::new(1, 1, 1);
+
+        let config = make_config(8, k, n);
+
+        load_tensor_permuted_test_launch::<F32, R>(
+            rhs.client.clone(),
+            cube_count,
+            cube_dim,
+            TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
+            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
+            4,
+            4,
+            8,
+            config,
+            false,
+        );
+
+        let expected = &[
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 68.0, 82.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 69.0, 83.0, 0.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
+        ];
+        assert_equals::<R>(sm_out, expected, device);
+    }
 }
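
Note: the expected values in load_lhs_plain_unit_test follow from range_tensor's row-major numbering, where entry (r, c) of the 16 x 16 input is 16r + c; with k = 8 and unit_row = unit_col = 4 the tile covers global rows 12..16 and columns 4..8. A quick recomputation of the first tile row (illustrative):

```rust
// Illustrative recomputation of the first expected tile row in
// load_lhs_plain_unit_test: range_tensor assigns value 16*r + c to entry
// (r, c); the tile's first row sits at global row k + unit_row = 12.
fn main() {
    let (k, unit_row, unit_col) = (8usize, 4usize, 4usize);
    let first_row: Vec<f32> = (0..4)
        .map(|j| (16 * (k + unit_row) + unit_col + j) as f32)
        .collect();
    assert_eq!(first_row, vec![196.0, 197.0, 198.0, 199.0]);
}
```

The later runs 212.., 228.., 244.. in the expected slice are the same columns one row further down (+16 per row).
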
@@ -518,6 +518,54 @@ mod tests {
         );
     }
 
+    #[test]
+    fn swapped_lhs_row_col_large_uneven_m() {
+        let (m, k, n) = (252, 256, 256);
+        let swap_lhs = [2, 3];
+        let swap_rhs = [0, 0];
+        let shape_lhs = [3, 2, k, m];
+        let shape_rhs = [3, 2, k, n];
+        same_as_reference_swapped_dims(
+            MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()),
+            swap_lhs,
+            swap_rhs,
+            shape_lhs,
+            shape_rhs,
+        );
+    }
+
+    #[test]
+    fn swapped_rhs_row_col_large_uneven_n() {
+        let (m, k, n) = (256, 256, 252);
+        let swap_lhs = [0, 0];
+        let swap_rhs = [2, 3];
+        let shape_lhs = [3, 2, m, k];
+        let shape_rhs = [3, 2, n, k];
+        same_as_reference_swapped_dims(
+            MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()),
+            swap_lhs,
+            swap_rhs,
+            shape_lhs,
+            shape_rhs,
+        );
+    }
+
+    #[test]
+    fn swapped_both_row_col_large_uneven_k() {
+        let (m, k, n) = (256, 252, 256);
+        let swap_lhs = [2, 3];
+        let swap_rhs = [2, 3];
+        let shape_lhs = [3, 2, k, m];
+        let shape_rhs = [3, 2, n, k];
+        same_as_reference_swapped_dims(
+            MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()),
+            swap_lhs,
+            swap_rhs,
+            shape_lhs,
+            shape_rhs,
+        );
+    }
+
     #[test]
     fn swapped_row_with_batch_no_padding() {
         let swap_lhs = [0, 3];
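
Note: each of these three tests shaves one extent to 252 = 256 - 4 so the last cube overhangs the tensor edge, exercising the new bounds-check paths on permuted inputs. A sketch of what the swapped shapes imply (illustrative; the 64-wide block is an assumption based on Tiling2dConfig's defaults):

```rust
// Illustrative: a tensor allocated as [3, 2, k, m] then viewed through
// swap_dims(2, 3) has the logical lhs shape [3, 2, m, k] but is transposed
// in memory, so the kernel takes the transposed-in-global-memory lhs path.
fn main() {
    let (m, k) = (252usize, 256usize);
    let physical = [3, 2, k, m];
    let logical = [physical[0], physical[1], physical[3], physical[2]]; // after swap_dims(2, 3)
    assert_eq!(logical, [3, 2, m, k]);
    // 252 is a multiple of the tile size (4) but not of the assumed block
    // size (64), so the last cube per row overhangs and check_m_bounds
    // must clamp the reads.
    assert_eq!(m % 4, 0);
    assert_ne!(m % 64, 0);
}
```
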
@@ -99,6 +99,18 @@ mod tests {
         load_shared_memory_tests::load_lhs_transposed_cube_test::<TestRuntime>(&Default::default())
     }
 
+    #[test]
+    pub fn load_lhs_plain_unit_test() {
+        load_shared_memory_tests::load_lhs_plain_unit_test::<TestRuntime>(&Default::default())
+    }
+
+    #[test]
+    pub fn load_lhs_plain_out_of_bounds_unit_test() {
+        load_shared_memory_tests::load_lhs_plain_out_of_bounds_unit_test::<TestRuntime>(
+            &Default::default(),
+        )
+    }
+
     #[test]
     pub fn load_lhs_transposed_out_of_bounds_cube_test() {
         load_shared_memory_tests::load_lhs_transposed_out_of_bounds_cube_test::<TestRuntime>(
@@ -128,6 +140,18 @@ mod tests {
         load_shared_memory_tests::load_rhs_plain_cube_offset_test::<TestRuntime>(&Default::default())
     }
 
+    #[test]
+    pub fn load_rhs_transposed_unit_test() {
+        load_shared_memory_tests::load_rhs_transposed_unit_test::<TestRuntime>(&Default::default())
+    }
+
+    #[test]
+    pub fn load_rhs_transposed_out_of_bounds_unit_test() {
+        load_shared_memory_tests::load_rhs_transposed_out_of_bounds_unit_test::<TestRuntime>(
+            &Default::default(),
+        )
+    }
+
     #[test]
     pub fn write_results_inner_loop_unit_test() {
         write_output_tests::write_results_inner_loop_unit_test::<TestRuntime>(&Default::default())