wip fixing last bug

louisfd 2024-06-26 17:32:53 -04:00
parent 122b909e75
commit b93e984555
4 changed files with 237 additions and 25 deletions

@@ -150,23 +150,16 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
let lhs = into_contiguous(lhs);
let rhs = into_contiguous(rhs);
// let lhs = match lhs.batch_swapped_with_row_col() {
// true => into_contiguous(lhs),
// false => lhs,
// };
// let rhs = match rhs.batch_swapped_with_row_col() {
// true => into_contiguous(rhs),
// false => rhs,
// };
config.block_size_m = 64;
config.block_size_n = 64;
config.block_size_k = 32; // k must be <= both m and n
let cube_count = tiling2d_launch_options(&out.shape, config.clone());
let vectorization_factor = 4;
let x = (config.block_size_m / 4) as u32;
let y = (config.block_size_n / 4) as u32;
let vectorization_factor = 1;
let tile_size = 4;
let x = (config.block_size_m / tile_size) as u32;
let y = (config.block_size_n / tile_size) as u32;
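// x and y are presumably used as the cube's unit grid: one unit per tile_size x tile_size sub-tile of the output block.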
let settings = KernelSettings::default()
.vectorize_input(0, vectorization_factor as u8)
.vectorize_input(1, vectorization_factor as u8)
@@ -180,7 +173,7 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
TensorHandle::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
TensorHandle::new(&out.handle, &out.strides, &out.shape.dims),
CubeTiling2dConfig::new(config, m, k, n, vectorization_factor as usize),
CubeTiling2dConfig::new(config, m, k, n, tile_size as usize),
);
out

@@ -122,12 +122,14 @@ fn load_tile<F: Float>(
let tensor_stride = tensor.stride(rank - UInt::new(2));
let tensor_position_base = load_row * tensor_stride + load_col + cube_offset;
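// Flattened offset of this unit's first element: its row scaled by the row stride, plus its column and the cube's offset.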
// TODO: computing `col` here should only be needed when bounds checks are enabled
let col = skip_col + load_col;
if Comptime::get(check_vertical_bounds) {
let row = skip_row + load_row;
let dim_vertical = tensor.shape(rank - UInt::new(2));
if Comptime::get(check_horizontal_bounds) {
let col = skip_col + load_col;
let dim_horizontal = tensor.shape(rank - UInt::new(1));
if col >= dim_horizontal {
@@ -140,7 +142,9 @@ fn load_tile<F: Float>(
tensor_position_base,
tensor_stride,
tile,
col,
tile_size,
unroll,
);
}
} else {
@@ -151,12 +155,13 @@ fn load_tile<F: Float>(
tensor_position_base,
tensor_stride,
tile,
col,
tile_size,
unroll,
);
}
} else {
if Comptime::get(check_horizontal_bounds) {
let col = skip_col + load_col;
let dim_horizontal = tensor.shape(rank - UInt::new(1));
if col >= dim_horizontal {
read_zeros::<F>(tile, tile_size, unroll);
@@ -166,6 +171,7 @@ fn load_tile<F: Float>(
tensor_position_base,
tensor_stride,
tile,
col,
tile_size,
unroll,
);
@@ -176,6 +182,7 @@ fn load_tile<F: Float>(
tensor_position_base,
tensor_stride,
tile,
col,
tile_size,
unroll,
);
@@ -244,9 +251,10 @@ fn read_partial<F: Float>(
position_base: UInt,
stride: UInt,
mut tile: Array<F>,
col: UInt,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let vectorization_factor = Comptime::runtime(Comptime::vectorization(tensor));
let tile_size_runtime = Comptime::runtime(tile_size);
let mut num_reads = UInt::new(0);
@@ -255,7 +263,16 @@ fn read_partial<F: Float>(
}
for i in range(0u32, num_reads, Comptime::new(false)) {
tile[i] = tensor[(position_base + i * stride) / vectorization_factor];
read_inner::<F>(
tensor,
position_base,
stride,
tile,
i,
col,
tile_size,
unroll,
);
}
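// Tile lines beyond num_reads fall outside the tensor; they are padded with the zero vector below.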
let zeros = F::vectorized(0., Comptime::get(tile_size));
@@ -269,13 +286,77 @@ fn read_whole<F: Float>(
tensor: Tensor<F>,
position_base: UInt,
stride: UInt,
mut tile: Array<F>,
tile: Array<F>,
col: UInt,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let vectorization_factor = Comptime::runtime(Comptime::vectorization(tensor));
for i in range(0u32, Comptime::get(tile_size), unroll) {
tile[i] = tensor[(position_base + i * stride) / vectorization_factor];
read_inner::<F>(
tensor,
position_base,
stride,
tile,
i,
col,
tile_size,
unroll,
);
}
}
#[cube]
fn read_inner<F: Float>(
tensor: Tensor<F>,
position_base: UInt,
stride: UInt,
mut tile: Array<F>,
i: UInt,
col: UInt,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let vectorization_factor = Comptime::vectorization(tensor);
let position = position_base + i * stride;
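// Fast path: the tensor's vectorization width matches the tile width, so a single vectorized read fills the whole tile line.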
if tile_size == vectorization_factor {
tile[i] = tensor[position / Comptime::runtime(vectorization_factor)];
} else {
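// Mismatched widths: assemble the tile line piece by piece, bounds-checking each read against the end of the row.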
let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1);
let mut tile_entry = F::vectorized(0., Comptime::get(tile_size));
let dim_horizontal = tensor.shape(tensor.rank() - UInt::new(1));
let mut num_reads = UInt::new(0);
if dim_horizontal > col {
num_reads = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size));
}
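// num_reads is the number of scalars of this line that lie inside the tensor row; it stays 0 when col is already past the edge.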
for x in range(
0u32,
Comptime::get(tile_size / vectorization_factor),
unroll,
) {
// TODO refactor
if Comptime::get(is_scalar) {
// TODO: this bounds check runs even when no checks are needed
if x * Comptime::runtime(vectorization_factor) < num_reads {
tile_entry[x] = tensor[position + x];
}
} else {
// TODO: this bounds check runs even when no checks are needed
if x * Comptime::runtime(vectorization_factor) < num_reads {
let intermediate =
tensor[position / Comptime::runtime(vectorization_factor) + x];
for y in range(0u32, Comptime::get(vectorization_factor), unroll) {
tile_entry[x * Comptime::runtime(vectorization_factor) + y] =
intermediate[y];
}
}
}
}
tile[i] = tile_entry;
}
}
@@ -292,6 +373,7 @@ pub mod tests {
UInt::new(0),
tensor.stride(0),
tile,
UInt::new(0),
tile_size,
Comptime::new(true),
)
@@ -311,7 +393,9 @@ pub mod tests {
UInt::new(8),
tensor.stride(0),
tile,
UInt::new(0),
tile_size,
Comptime::new(true),
)
}
@@ -462,7 +546,7 @@ pub mod tests {
}
/// Exported test
pub fn read_whole_unit_test<R: JitRuntime>(device: &R::Device) {
pub fn read_whole_vectorized_like_tile_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
let tile_size = 4;
@@ -498,6 +582,120 @@ pub mod tests {
assert_eq!(actual, expected);
}
/// Exported test
pub fn read_whole_vectorized_less_than_tile_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
let tile_size = 4;
let vectorization_factor = 2;
let tensor = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..16, device)
.reshape([4, 4])
.float()
.into_primitive();
let client = R::client(device);
let tile = client.empty(tile_size * tile_size * core::mem::size_of::<f32>());
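// Backing buffer for the output tile: tile_size lines of tile_size f32 values each.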
// Unit test
let cube_count = CubeCount::new(1, 1, 1);
let settings = KernelSettings::default()
.cube_dim(CubeDim::new(1, 1, 1))
.vectorize_input(0, vectorization_factor as u8)
.vectorize_output(0, tile_size as u8);
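// Inputs are loaded 2 elements at a time while the output line is tile_size wide, exercising the mismatched-width path in read_inner.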
read_whole_test_launch::<F32, R>(
client.clone(),
cube_count,
settings,
TensorHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
ArrayHandle::new(&tile, 4),
tile_size.into(),
);
let actual = client.read(tile.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
let expected = &[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
];
assert_eq!(actual, expected);
}
/// Exported test
pub fn read_whole_scalar_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
let tile_size = 4;
let vectorization_factor = 1;
let tensor = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..16, device)
.reshape([4, 4])
.float()
.into_primitive();
let client = R::client(device);
let tile = client.empty(tile_size * tile_size * core::mem::size_of::<f32>());
// Unit test
let cube_count = CubeCount::new(1, 1, 1);
let settings = KernelSettings::default()
.cube_dim(CubeDim::new(1, 1, 1))
.vectorize_input(0, vectorization_factor as u8)
.vectorize_output(0, tile_size as u8);
read_whole_test_launch::<F32, R>(
client.clone(),
cube_count,
settings,
TensorHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
ArrayHandle::new(&tile, 4),
tile_size.into(),
);
let actual = client.read(tile.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
let expected = &[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
];
assert_eq!(actual, expected);
}
/// Exported test
pub fn read_whole_scalar_out_of_bound_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;
let tile_size = 4;
let vectorization_factor = 1;
let tensor = burn_tensor::Tensor::<B<R>, 1, burn_tensor::Int>::arange(0..8, device)
.reshape([4, 2])
.float()
.into_primitive();
let client = R::client(device);
let tile = client.empty(tile_size * tile_size * core::mem::size_of::<f32>());
// Unit test
let cube_count = CubeCount::new(1, 1, 1);
let settings = KernelSettings::default()
.cube_dim(CubeDim::new(1, 1, 1))
.vectorize_input(0, vectorization_factor as u8)
.vectorize_output(0, tile_size as u8);
read_whole_test_launch::<F32, R>(
client.clone(),
cube_count,
settings,
TensorHandle::new(&tensor.handle, &tensor.strides, &tensor.shape.dims),
ArrayHandle::new(&tile, 4),
tile_size.into(),
);
let actual = client.read(tile.binding()).read_sync().unwrap();
let actual = f32::from_bytes(&actual);
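// The source tensor is 4x2, so only the first two entries of each 4-wide tile line hold data; the rest remain zero.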
let expected = &[
0.0, 1.0, 0.0, 0.0, 2.0, 3.0, 0.0, 0.0, 4.0, 5.0, 0.0, 0.0, 6.0, 7.0, 0.0, 0.0,
];
assert_eq!(actual, expected);
}
/// Exported test
pub fn read_partial_unit_test<R: JitRuntime>(device: &R::Device) {
pub type B<R> = JitBackend<R, f32, i32>;

@@ -408,7 +408,7 @@ mod tests {
#[test]
pub fn straightforward() {
test_with_params(4, 4, 4, 1, 1);
test_with_params(9, 9, 9, 1, 1);
}
#[test]
@@ -483,12 +483,12 @@ mod tests {
#[test]
pub fn large_k_multiple() {
test_with_params(256, 288, 264, 1, 1);
test_with_params(256, 256, 256, 1, 1);
}
#[test]
pub fn large_k_not_multiple() {
test_with_params(1000, 200, 1000, 5, 2);
test_with_params(257, 261, 254, 5, 2);
}
#[test]

@@ -38,8 +38,29 @@ mod tests {
}
#[test]
pub fn tiling2d_matmul_read_whole_vectorized_test() {
load_shared_memory_tests::read_whole_unit_test::<TestRuntime>(&Default::default())
pub fn tiling2d_matmul_read_whole_vectorized_like_tile_test() {
load_shared_memory_tests::read_whole_vectorized_like_tile_test::<TestRuntime>(
&Default::default(),
)
}
#[test]
pub fn tiling2d_matmul_read_whole_vectorized_less_than_tile_test() {
load_shared_memory_tests::read_whole_vectorized_less_than_tile_test::<TestRuntime>(
&Default::default(),
)
}
#[test]
pub fn tiling2d_matmul_read_whole_scalar_test() {
load_shared_memory_tests::read_whole_scalar_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn tiling2d_matmul_read_whole_scalar_out_of_bound_test() {
load_shared_memory_tests::read_whole_scalar_out_of_bound_test::<TestRuntime>(
&Default::default(),
)
}
#[test]