refactor load shared memory file

This commit is contained in:
louisfd 2024-07-04 10:52:54 -04:00
parent 74b06fbf55
commit cbb4de7156
8 changed files with 969 additions and 889 deletions

View File

@ -87,6 +87,7 @@ impl CubeTiling2dConfig {
&& config.block_size_n % config.tile_size == 0,
"Tiling 2d algorithm assumes tile size divides block size perfectly. "
);
CubeTiling2dConfig {
block_size_m: UInt::new(config.block_size_m as u32),
block_size_k: UInt::new(config.block_size_k as u32),

View File

@ -51,13 +51,13 @@ pub(crate) struct CubeTiling2dInfo {
pub out_stride: UInt,
}
#[derive(CubeType)]
#[derive(CubeType, Copy, Clone)]
pub(crate) struct SharedMemories<F: Float> {
pub lhs: SharedMemory<F>,
pub rhs: SharedMemory<F>,
}
#[derive(CubeType)]
#[derive(CubeType, Copy, Clone)]
/// Number of elements in previous batches
/// Not divided by vectorization factor
pub(crate) struct BatchOffsets {
@ -217,14 +217,38 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
.next()
.unwrap_or(1)
};
let lhs_vectorization = match lhs_transposed {
true => vectorization(m),
false => vectorization(k),
};
let rhs_vectorization = match rhs_transposed {
true => vectorization(k),
false => vectorization(n),
};
let out_vectorization = vectorization(n);
tiling2d_cube_launch::<E::FloatPrimitive, R>(
client,
tiling2d_cube_count(&out.shape, &config),
tiling2d_cube_dim(&config),
TensorArg::vectorized(vectorization(k), &lhs.handle, &lhs.strides, &lhs.shape.dims),
TensorArg::vectorized(vectorization(n), &rhs.handle, &rhs.strides, &rhs.shape.dims),
TensorArg::vectorized(vectorization(n), &out.handle, &out.strides, &out.shape.dims),
TensorArg::vectorized(
lhs_vectorization,
&lhs.handle,
&lhs.strides,
&lhs.shape.dims,
),
TensorArg::vectorized(
rhs_vectorization,
&rhs.handle,
&rhs.strides,
&rhs.shape.dims,
),
TensorArg::vectorized(
out_vectorization,
&out.handle,
&out.strides,
&out.shape.dims,
),
CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed),
);

View File

@ -5,14 +5,11 @@ use crate::kernel::matmul::config::CubeTiling2dConfig;
use super::{
base::{BatchOffsets, Coordinates, CubeTiling2dInfo, SharedMemories},
compute_loop::{compute_loop, compute_loop_expand},
load_shared_memory::{
load_lhs_transposed, load_lhs_transposed_expand, load_rhs_plain, load_rhs_plain_expand,
},
load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand},
write_output::{write_to_output, write_to_output_expand},
};
#[cube]
#[allow(unused_mut)]
pub(crate) fn block_loop<F: Float>(
lhs: &Tensor<F>,
rhs: &Tensor<F>,
@ -31,8 +28,7 @@ pub(crate) fn block_loop<F: Float>(
for k in range(0u32, n_loops, Comptime::new(false)) {
let k = k * Comptime::runtime(block_size_k);
load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
load_to_shared_memories(lhs, rhs, coordinates, k, offsets, shared, config, info);
sync_units();

View File

@ -2,7 +2,43 @@ use burn_cube::prelude::*;
use crate::kernel::matmul::config::CubeTiling2dConfig;
use super::base::{Coordinates, CubeTiling2dInfo};
use super::{
base::{BatchOffsets, Coordinates, CubeTiling2dInfo, SharedMemories},
tile_read::{read_tile_from_global_memory, read_tile_from_global_memory_expand},
tile_write::{
write_tile_plain, write_tile_plain_expand, write_tile_transposed,
write_tile_transposed_expand,
},
};
#[cube]
/// Loads the lhs and rhs tiles for the current `k` step into shared memory,
/// dispatching at compile time on the layout flags in `config`.
pub(crate) fn load_to_shared_memories<F: Float>(
lhs: &Tensor<F>,
rhs: &Tensor<F>,
coordinates: Coordinates,
k: UInt,
offsets: BatchOffsets,
shared: SharedMemories<F>,
config: Comptime<CubeTiling2dConfig>,
info: CubeTiling2dInfo,
) {
let lhs_transposed = Comptime::map(config, |c| c.lhs_transposed);
let rhs_transposed = Comptime::map(config, |c| c.rhs_transposed);
// Lhs must be loaded as transposed. If it already is in global memory, we load as plain.
if Comptime::get(lhs_transposed) {
// NOTE(review): stub — `load_lhs_plain` is not implemented yet, so a
// transposed lhs is currently not loaded at all. TODO confirm intent.
// load_lhs_plain::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
} else {
load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
}
// Rhs must be loaded as plain. If it is transposed in global memory, we transpose it back.
if Comptime::get(rhs_transposed) {
// NOTE(review): stub — `load_rhs_transposed` is not implemented yet, so a
// transposed rhs is currently not loaded at all. TODO confirm intent.
// load_rhs_transposed::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
} else {
load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
}
}
#[cube]
pub(crate) fn load_lhs_transposed<F: Float>(
@ -24,7 +60,7 @@ pub(crate) fn load_lhs_transposed<F: Float>(
let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
load_tile::<F>(
read_tile_from_global_memory::<F>(
lhs,
&mut tile,
offset,
@ -69,7 +105,7 @@ pub(crate) fn load_rhs_plain<F: Float>(
let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
load_tile::<F>(
read_tile_from_global_memory::<F>(
rhs,
&mut tile,
offset,
@ -95,414 +131,6 @@ pub(crate) fn load_rhs_plain<F: Float>(
);
}
#[cube]
/// Reads one tile (`tile_size` lines) from global memory into `tile`,
/// selecting one of four read paths based on the comptime bounds-check flags
/// so that fully in-bounds tiles pay no runtime checks.
fn load_tile<F: Float>(
tensor: &Tensor<F>,
tile: &mut Array<F>,
cube_offset: UInt,
load_row: UInt,
load_col: UInt,
skip_row: UInt,
skip_col: UInt,
tensor_stride: UInt,
dim_vertical: UInt,
dim_horizontal: UInt,
check_vertical_bounds: Comptime<bool>,
check_horizontal_bounds: Comptime<bool>,
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let unroll = Comptime::map(config, |c| c.unroll_tile);
// Linear position of the tile's first element within the tensor buffer.
let tensor_position_base = load_row * tensor_stride + load_col + cube_offset;
if Comptime::get(check_vertical_bounds) {
// Absolute row in the tensor, used to clamp reads against dim_vertical.
let row = skip_row + load_row;
if Comptime::get(check_horizontal_bounds) {
let col = skip_col + load_col;
read_with_both_checks::<F>(
tensor,
row,
col,
tensor_position_base,
tensor_stride,
dim_vertical,
dim_horizontal,
tile,
tile_size,
unroll,
);
} else {
read_with_vertical_checks::<F>(
tensor,
row,
tensor_position_base,
tensor_stride,
dim_vertical,
tile,
tile_size,
unroll,
);
}
} else if Comptime::get(check_horizontal_bounds) {
let col = skip_col + load_col;
read_with_horizontal_checks::<F>(
tensor,
col,
tensor_position_base,
tensor_stride,
dim_horizontal,
tile,
tile_size,
unroll,
);
} else {
// Fast path: no bound can be exceeded, read unconditionally.
read_without_checks::<F>(
tensor,
tensor_position_base,
tensor_stride,
tile,
tile_size,
unroll,
);
}
}
#[cube]
/// Writes a tile into shared memory line by line, without transposition.
/// When `check_sm_bounds` is set, rows past `block_size_k` are skipped.
fn write_tile_plain<F: Float>(
tile: &Array<F>,
mut shared_memory: SharedMemory<F>,
write_row: UInt,
write_col: UInt,
sm_stride: UInt,
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let unroll = Comptime::map(config, |c| c.unroll_tile);
let check_sm_bounds = Comptime::map(config, |c| c.check_sm_bounds);
let tile_size_runtime = Comptime::runtime(tile_size);
// Linear position of the tile's first element within shared memory.
let sm_position_base = write_row * sm_stride + write_col;
if Comptime::get(check_sm_bounds) {
let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
if write_row < sm_dim_vertical {
for i in range(0u32, Comptime::get(tile_size), unroll) {
// Division by tile_size converts an element index into a vectorized index.
shared_memory[(sm_position_base + i * sm_stride) / tile_size_runtime] = tile[i];
}
}
} else {
for i in range(0u32, Comptime::get(tile_size), unroll) {
shared_memory[(sm_position_base + i * sm_stride) / tile_size_runtime] = tile[i];
}
}
}
#[cube]
/// Writes a tile into shared memory with a transposition. A scalar tile
/// (tile_size == 1) is its own transpose, so it is written directly;
/// otherwise the transposition is delegated to `transpose_tile_to_shared_memory`.
fn write_tile_transposed<F: Float>(
tile: &Array<F>,
mut shared_memory: SharedMemory<F>,
write_row: UInt,
write_col: UInt,
sm_stride: UInt,
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let check_sm_bounds = Comptime::map(config, |c| c.check_sm_bounds);
let is_scalar = Comptime::map(tile_size, |c| c.val == 1);
// Linear position of the tile's first element within shared memory.
let sm_position_base = write_row * sm_stride + write_col;
if Comptime::get(is_scalar) {
if Comptime::get(check_sm_bounds) {
let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
if write_row < sm_dim_vertical {
shared_memory[sm_position_base] = tile[0];
}
} else {
shared_memory[sm_position_base] = tile[0];
}
} else if Comptime::get(check_sm_bounds) {
let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
if write_row < sm_dim_vertical {
transpose_tile_to_shared_memory::<F>(
tile,
shared_memory,
sm_position_base,
sm_stride,
config,
);
}
} else {
transpose_tile_to_shared_memory::<F>(
tile,
shared_memory,
sm_position_base,
sm_stride,
config,
);
}
}
#[cube]
/// Transposes `tile` while writing it into shared memory: output line `i`
/// is built by gathering element `i` of every input line.
fn transpose_tile_to_shared_memory<F: Float>(
tile: &Array<F>,
mut shared_memory: SharedMemory<F>,
sm_position_base: UInt,
sm_stride: UInt,
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let unroll = Comptime::map(config, |c| c.unroll_tile);
for i in range(0u32, Comptime::get(tile_size), unroll) {
// Gather column i of the tile into one vectorized value.
let mut transposed = F::vectorized_empty(Comptime::get(tile_size));
// Unrolling this one makes the difference
for j in range(0u32, Comptime::get(tile_size), unroll) {
transposed[j] = tile[j][i];
}
// Division by tile_size converts an element index into a vectorized index.
let sm_position = (sm_position_base + i * sm_stride) / Comptime::runtime(tile_size);
shared_memory[sm_position] = transposed;
}
}
#[cube]
/// Reads a tile with both vertical and horizontal bound checks: only the
/// in-bounds lines are read (each line itself checked horizontally), and the
/// remaining lines are zero-filled.
fn read_with_both_checks<F: Float>(
tensor: &Tensor<F>,
row: UInt,
col: UInt,
position_base: UInt,
stride: UInt,
dim_vertical: UInt,
dim_horizontal: UInt,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let tile_size_runtime = Comptime::runtime(tile_size);
// Number of lines that actually exist below `row`; 0 if row is out of bounds.
let mut num_reads = UInt::new(0);
if dim_vertical > row {
num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
}
for i in range(0u32, num_reads, Comptime::new(false)) {
read_tile_line_with_checks::<F>(
tensor,
col,
position_base,
stride,
dim_horizontal,
tile,
i,
tile_size,
unroll,
);
}
// Pad out-of-bounds lines with zeros.
let zeros = F::vectorized(0., Comptime::get(tile_size));
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
tile[i] = zeros;
}
}
#[cube]
/// Reads a tile with vertical bound checks only: in-bounds lines are read
/// without horizontal checks, out-of-bounds lines are zero-filled.
fn read_with_vertical_checks<F: Float>(
tensor: &Tensor<F>,
row: UInt,
position_base: UInt,
stride: UInt,
dim_vertical: UInt,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let tile_size_runtime = Comptime::runtime(tile_size);
// Number of lines that actually exist below `row`; 0 if row is out of bounds.
let mut num_reads = UInt::new(0);
if dim_vertical > row {
num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
}
for i in range(0u32, num_reads, Comptime::new(false)) {
read_tile_line_without_checks::<F>(
tensor,
position_base,
stride,
tile,
i,
tile_size,
unroll,
);
}
// Pad out-of-bounds lines with zeros.
let zeros = F::vectorized(0., Comptime::get(tile_size));
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
tile[i] = zeros;
}
}
#[cube]
/// Reads a full tile with no bound checks; the caller guarantees every line
/// is in bounds. The loop can be unrolled at compile time.
fn read_without_checks<F: Float>(
tensor: &Tensor<F>,
position_base: UInt,
stride: UInt,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
for i in range(0u32, Comptime::get(tile_size), unroll) {
read_tile_line_without_checks::<F>(
tensor,
position_base,
stride,
tile,
i,
tile_size,
unroll,
);
}
}
#[cube]
/// Reads a full tile with horizontal bound checks on every line; vertical
/// bounds are guaranteed by the caller.
fn read_with_horizontal_checks<F: Float>(
tensor: &Tensor<F>,
col: UInt,
position_base: UInt,
stride: UInt,
dim_horizontal: UInt,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
for i in range(0u32, Comptime::get(tile_size), unroll) {
read_tile_line_with_checks::<F>(
tensor,
col,
position_base,
stride,
dim_horizontal,
tile,
i,
tile_size,
unroll,
);
}
}
#[cube]
/// Reads tile line `i` while guarding against reading past `dim_horizontal`.
/// When the tensor's vectorization matches the tile size, the line is one
/// vectorized load (or zeros if out of bounds); otherwise it is assembled
/// from smaller vectors, reading only as many as fit before the bound.
fn read_tile_line_with_checks<F: Float>(
tensor: &Tensor<F>,
col: UInt,
position_base: UInt,
stride: UInt,
dim_horizontal: UInt,
tile: &mut Array<F>,
i: UInt,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let vectorization_factor = Comptime::vectorization(tensor);
let runtime_vectorization = Comptime::runtime(vectorization_factor);
let position = position_base + i * stride;
if tile_size == vectorization_factor {
if col >= dim_horizontal {
tile[i] = F::vectorized(0., Comptime::get(tile_size));
} else {
// Division converts the element index into a vectorized index.
tile[i] = tensor[position / runtime_vectorization];
}
} else {
let tile_entry = F::vectorized_empty(Comptime::get(tile_size));
// Number of vectorized loads that stay within the horizontal bound.
let mut num_loops = UInt::new(0);
if dim_horizontal > col {
let num_reads = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size));
num_loops = num_reads / runtime_vectorization;
}
for x in range(0u32, num_loops, Comptime::new(false)) {
read_within_vector::<F>(
tensor,
tile_entry,
position,
x,
vectorization_factor,
unroll,
);
}
tile[i] = tile_entry;
}
}
#[cube]
/// Reads tile line `i` with no bound checks. When the tensor's vectorization
/// matches the tile size the line is one vectorized load; otherwise it is
/// assembled from `tile_size / vectorization_factor` smaller vectors.
fn read_tile_line_without_checks<F: Float>(
tensor: &Tensor<F>,
position_base: UInt,
stride: UInt,
tile: &mut Array<F>,
i: UInt,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let vectorization_factor = Comptime::vectorization(tensor);
let runtime_vectorization = Comptime::runtime(vectorization_factor);
let position = position_base + i * stride;
if tile_size == vectorization_factor {
// Division converts the element index into a vectorized index.
tile[i] = tensor[position / runtime_vectorization];
} else {
let tile_entry = F::vectorized_empty(Comptime::get(tile_size));
for j in range(
0u32,
Comptime::get(tile_size / vectorization_factor),
unroll,
) {
read_within_vector::<F>(
tensor,
tile_entry,
position,
j,
vectorization_factor,
unroll,
);
}
tile[i] = tile_entry;
}
}
#[cube]
/// Necessary when vectorization_factor < tile_size.
/// Copies the `i`-th vectorized chunk of the tensor into the matching slots
/// of `tile_entry`; with scalar vectorization a single element is copied.
fn read_within_vector<F: Float>(
tensor: &Tensor<F>,
mut tile_entry: F,
position: UInt,
i: UInt,
vectorization_factor: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1);
let runtime_vectorization = Comptime::runtime(vectorization_factor);
if Comptime::get(is_scalar) {
tile_entry[i] = tensor[position + i];
} else {
// Spread the loaded vector's lanes into the wider tile entry.
let intermediate = tensor[position / runtime_vectorization + i];
for j in range(0u32, Comptime::get(vectorization_factor), unroll) {
tile_entry[i * runtime_vectorization + j] = intermediate[j];
}
}
}
#[cfg(feature = "export_tests")]
/// Exported tests for loading to shared memory
pub mod tests {
@ -513,145 +141,6 @@ pub mod tests {
use super::{super::base::CoordinatesExpand, super::base::CubeTiling2dInfoExpand, *};
#[cube(launch)]
#[allow(unused_mut)]
/// Test kernel: reads a whole tile starting at the tensor origin, with or
/// without horizontal bound checks depending on `bound_check_horizontal`.
fn read_whole_test<F: Float>(
tensor: &Tensor<F>,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
bound_check_horizontal: Comptime<bool>,
) {
if Comptime::get(bound_check_horizontal) {
read_with_horizontal_checks::<F>(
tensor,
UInt::new(0),
UInt::new(0),
tensor.stride(0),
tensor.shape(1),
tile,
tile_size,
Comptime::new(true),
);
} else {
read_without_checks::<F>(
tensor,
UInt::new(0),
tensor.stride(0),
tile,
tile_size,
Comptime::new(true),
);
}
}
#[cube(launch)]
#[allow(unused_mut)]
/// Test kernel: reads starting at row 2 / offset 8 so that part of the tile
/// falls out of bounds, exercising the vertical (and optionally horizontal)
/// check paths with zero-padding.
fn read_partial_test<F: Float>(
tensor: &Tensor<F>,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
bound_check_horizontal: Comptime<bool>,
) {
if Comptime::get(bound_check_horizontal) {
read_with_both_checks::<F>(
tensor,
UInt::new(2),
UInt::new(8),
UInt::new(0),
tensor.stride(0),
tensor.shape(0),
tensor.shape(1),
tile,
tile_size,
Comptime::new(true),
);
} else {
read_with_vertical_checks::<F>(
tensor,
UInt::new(2),
UInt::new(8),
tensor.stride(0),
tensor.shape(0),
tile,
tile_size,
Comptime::new(true),
);
}
}
#[cube(launch)]
#[allow(unused_mut)]
/// Test kernel: exercises `load_tile` at a chosen (unit_row, unit_col) within
/// `lhs`, with bound-check flags taken from the config's m/k dimensions.
fn load_tile_test<F: Float>(
lhs: &Tensor<F>,
tile: &mut Array<F>,
unit_row: UInt,
unit_col: UInt,
config: Comptime<CubeTiling2dConfig>,
) {
let cube_offset = UInt::new(0);
let check_vertical_bounds = Comptime::map(config, |c| c.check_m_bounds);
let check_horizontal_bounds = Comptime::map(config, |c| c.check_k_bounds);
// Strides/shapes are taken from the last two dimensions (row, col).
let lhs_stride = lhs.stride(lhs.rank() - UInt::new(2));
let dim_m = lhs.shape(lhs.rank() - UInt::new(2));
let dim_k = lhs.shape(lhs.rank() - UInt::new(1));
load_tile::<F>(
lhs,
tile,
cube_offset,
unit_row,
unit_col,
UInt::new(0),
UInt::new(0),
lhs_stride,
dim_m,
dim_k,
check_vertical_bounds,
check_horizontal_bounds,
config,
);
}
#[cube(launch)]
/// Test kernel: writes `tile` into a fresh shared memory (plain or
/// transposed), then copies the shared memory out so the host can inspect it.
fn write_tile_test<F: Float>(
tile: &Array<F>,
sm_out: &mut Array<F>,
config: Comptime<CubeTiling2dConfig>,
transposed: Comptime<bool>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let block_size_m = Comptime::map(config, |c| c.block_size_m);
let block_size_k = Comptime::map(config, |c| c.block_size_k);
let sm_stride = block_size_m;
let sm_size = Comptime::runtime(block_size_k * block_size_m);
let shared_memory = SharedMemory::<F>::vectorized(sm_size, Comptime::get(tile_size));
if Comptime::get(transposed) {
write_tile_transposed(
tile,
shared_memory,
UInt::new(0),
UInt::new(0),
Comptime::runtime(sm_stride),
config,
);
} else {
write_tile_plain(
tile,
shared_memory,
UInt::new(0),
UInt::new(0),
Comptime::runtime(sm_stride),
config,
);
}
// Export shared memory so the host-side assertion can read it.
for i in range(0u32, sm_size, Comptime::new(false)) {
sm_out[i] = shared_memory[i];
}
}
#[cube(launch)]
fn load_tensor_test<F: Float>(
tensor: &Tensor<F>,
@ -762,314 +251,6 @@ pub mod tests {
}
}
/// Exported test: vectorization factor equal to the tile size — the whole
/// 4x4 range tensor should be read back unchanged, no bound checks.
pub fn read_whole_vectorized_like_tile_test<R: JitRuntime>(device: &R::Device) {
let tensor = range_tensor::<R>(4, 4, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
read_whole_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
TILE_SIZE as u8,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
TILE_SIZE.into(),
false,
);
assert_equals::<R>(
tile,
&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
15.0,
],
device,
);
}
/// Exported test: vectorization factor (2) smaller than the tile size — the
/// line is assembled from several vectors; result must still match the input.
pub fn read_whole_vectorized_less_than_tile_test<R: JitRuntime>(device: &R::Device) {
let vectorization_factor = 2;
let tensor = range_tensor::<R>(4, 4, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
read_whole_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
TILE_SIZE.into(),
false,
);
assert_equals::<R>(
tile,
&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
15.0,
],
device,
);
}
/// Exported test: scalar path (vectorization factor 1) — every element is
/// read individually; result must still match the input.
pub fn read_whole_scalar_test<R: JitRuntime>(device: &R::Device) {
let vectorization_factor = 1;
let tensor = range_tensor::<R>(4, 4, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
read_whole_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
TILE_SIZE.into(),
false,
);
assert_equals::<R>(
tile,
&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
15.0,
],
device,
);
}
/// Exported test: 4x2 tensor narrower than the tile with horizontal bound
/// checks enabled — out-of-bounds columns must come back as zeros.
pub fn read_whole_scalar_out_of_bound_test<R: JitRuntime>(device: &R::Device) {
let vectorization_factor = 2;
let tensor = range_tensor::<R>(4, 2, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
read_whole_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
TILE_SIZE.into(),
true,
);
assert_equals::<R>(
tile,
&[
0.0, 1.0, 0.0, 0.0, 2.0, 3.0, 0.0, 0.0, 4.0, 5.0, 0.0, 0.0, 6.0, 7.0, 0.0, 0.0,
],
device,
);
}
/// Exported test: read starting past the middle of a 4x4 tensor — the last
/// two tile lines fall out of bounds and must be zero-padded.
pub fn read_partial_unit_test<R: JitRuntime>(device: &R::Device) {
let tensor = range_tensor::<R>(4, 4, device);
let tile = create_empty::<R>(4, 4, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
read_partial_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
TILE_SIZE as u8,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
TILE_SIZE.into(),
false,
);
let expected = &[
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
assert_equals::<R>(tile, expected, device);
}
/// Exported test: 8x8 tensor with an 8x8x8 config — no bound checks needed;
/// the tile at (0, 0) is the top-left 4x4 of the tensor.
pub fn load_tile_no_checks_unit_test<R: JitRuntime>(device: &R::Device) {
let tensor = range_tensor::<R>(8, 8, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
let config = make_config(8, 8, 8);
load_tile_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
TILE_SIZE as u8,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
0,
0,
config,
);
let expected = &[
0.0, 1.0, 2.0, 3.0, 8.0, 9.0, 10.0, 11.0, 16.0, 17.0, 18.0, 19.0, 24.0, 25.0, 26.0,
27.0,
];
assert_equals::<R>(tile, expected, device);
}
/// Exported test: m=6 forces vertical bound checks — a tile starting at
/// row 4 has only two valid lines; the rest must be zeros.
pub fn load_tile_vertical_checks_unit_test<R: JitRuntime>(device: &R::Device) {
let tensor = range_tensor::<R>(6, 8, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
let config = make_config(6, 8, 8);
load_tile_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
TILE_SIZE as u8,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
4,
0,
config,
);
let expected = &[
32.0, 33.0, 34.0, 35.0, 40.0, 41.0, 42.0, 43.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
assert_equals::<R>(tile, expected, device);
}
/// Exported test: k=4 forces horizontal bound checks — a tile starting at
/// column 4 is entirely out of bounds and must come back as all zeros.
pub fn load_tile_horizontal_checks_unit_test<R: JitRuntime>(device: &R::Device) {
let tensor = range_tensor::<R>(8, 4, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
let config = make_config(8, 4, 8);
load_tile_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
TILE_SIZE as u8,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
0,
4,
config,
);
let expected = &[
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
assert_equals::<R>(tile, expected, device);
}
/// Exported test: plain (non-transposed) tile write into an 8-wide shared
/// memory — tile rows land row-major with an 8-element stride.
pub fn write_tile_plain_unit_test<R: JitRuntime>(device: &R::Device) {
let tile = range_tensor::<R>(4, 4, device);
let sm_out = create_empty::<R>(8, 8, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
let config = make_config(8, 8, 8);
write_tile_test_launch::<F32, R>(
tile.client.clone(),
cube_count,
cube_dim,
ArrayArg::vectorized(TILE_SIZE as u8, &tile.handle, 4),
ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 4),
config,
false,
);
let expected = &[
0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0, 8.0,
9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
assert_equals::<R>(sm_out, expected, device);
}
/// Exported test: transposed tile write — the shared memory holds the
/// transpose of the 4x4 input tile (columns become rows).
pub fn write_tile_transposed_unit_test<R: JitRuntime>(device: &R::Device) {
let tile = range_tensor::<R>(4, 4, device);
let sm_out = create_empty::<R>(8, 8, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
let config = make_config(8, 8, 8);
write_tile_test_launch::<F32, R>(
tile.client.clone(),
cube_count,
cube_dim,
ArrayArg::vectorized(TILE_SIZE as u8, &tile.handle, 4),
ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
config,
true,
);
let expected = &[
0.0, 4.0, 8.0, 12.0, 0.0, 0.0, 0.0, 0.0, 1.0, 5.0, 9.0, 13.0, 0.0, 0.0, 0.0, 0.0, 2.0,
6.0, 10.0, 14.0, 0.0, 0.0, 0.0, 0.0, 3.0, 7.0, 11.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
assert_equals::<R>(sm_out, expected, device);
}
/// Exported test
pub fn load_lhs_transposed_unit_test<R: JitRuntime>(device: &R::Device) {
let lhs = range_tensor::<R>(16, 16, device);

View File

@ -5,6 +5,8 @@ mod load_shared_memory;
mod outer_product;
#[cfg(feature = "export_tests")]
mod test_utils;
mod tile_read;
mod tile_write;
mod write_output;
pub use base::matmul_tiling_2d_cube;
@ -13,5 +15,6 @@ pub use base::matmul_tiling_2d_cube;
pub use {
compute_loop::tests as compute_loop_tests,
load_shared_memory::tests as load_shared_memory_tests,
outer_product::tests as outer_product_tests, write_output::tests as write_output_tests,
outer_product::tests as outer_product_tests, tile_read::tests as tile_read_tests,
tile_write::tests as tile_write_tests, write_output::tests as write_output_tests,
};

View File

@ -0,0 +1,672 @@
use burn_cube::prelude::*;
use crate::kernel::matmul::config::CubeTiling2dConfig;
#[cube]
/// Reads one tile (`tile_size` lines) from global memory into `tile`,
/// selecting one of four read paths based on the comptime bounds-check flags
/// so that fully in-bounds tiles pay no runtime checks.
pub(crate) fn read_tile_from_global_memory<F: Float>(
tensor: &Tensor<F>,
tile: &mut Array<F>,
cube_offset: UInt,
read_row: UInt,
read_col: UInt,
skip_row: UInt,
skip_col: UInt,
tensor_stride: UInt,
dim_vertical: UInt,
dim_horizontal: UInt,
check_vertical_bounds: Comptime<bool>,
check_horizontal_bounds: Comptime<bool>,
config: Comptime<CubeTiling2dConfig>,
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let unroll = Comptime::map(config, |c| c.unroll_tile);
// Linear position of the tile's first element within the tensor buffer.
let tensor_position_base = read_row * tensor_stride + read_col + cube_offset;
if Comptime::get(check_vertical_bounds) {
// Absolute row in the tensor, used to clamp reads against dim_vertical.
let row = skip_row + read_row;
if Comptime::get(check_horizontal_bounds) {
let col = skip_col + read_col;
read_with_both_checks::<F>(
tensor,
row,
col,
tensor_position_base,
tensor_stride,
dim_vertical,
dim_horizontal,
tile,
tile_size,
unroll,
);
} else {
read_with_vertical_checks::<F>(
tensor,
row,
tensor_position_base,
tensor_stride,
dim_vertical,
tile,
tile_size,
unroll,
);
}
} else if Comptime::get(check_horizontal_bounds) {
let col = skip_col + read_col;
read_with_horizontal_checks::<F>(
tensor,
col,
tensor_position_base,
tensor_stride,
dim_horizontal,
tile,
tile_size,
unroll,
);
} else {
// Fast path: no bound can be exceeded, read unconditionally.
read_without_checks::<F>(
tensor,
tensor_position_base,
tensor_stride,
tile,
tile_size,
unroll,
);
}
}
#[cube]
/// Reads a tile with both vertical and horizontal bound checks: only the
/// in-bounds lines are read (each line itself checked horizontally), and the
/// remaining lines are zero-filled.
fn read_with_both_checks<F: Float>(
tensor: &Tensor<F>,
row: UInt,
col: UInt,
position_base: UInt,
stride: UInt,
dim_vertical: UInt,
dim_horizontal: UInt,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let tile_size_runtime = Comptime::runtime(tile_size);
// Number of lines that actually exist below `row`; 0 if row is out of bounds.
let mut num_reads = UInt::new(0);
if dim_vertical > row {
num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
}
for i in range(0u32, num_reads, Comptime::new(false)) {
read_tile_line_with_checks::<F>(
tensor,
col,
position_base,
stride,
dim_horizontal,
tile,
i,
tile_size,
unroll,
);
}
// Pad out-of-bounds lines with zeros.
let zeros = F::vectorized(0., Comptime::get(tile_size));
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
tile[i] = zeros;
}
}
#[cube]
/// Reads a tile with vertical bound checks only: in-bounds lines are read
/// without horizontal checks, out-of-bounds lines are zero-filled.
fn read_with_vertical_checks<F: Float>(
tensor: &Tensor<F>,
row: UInt,
position_base: UInt,
stride: UInt,
dim_vertical: UInt,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let tile_size_runtime = Comptime::runtime(tile_size);
// Number of lines that actually exist below `row`; 0 if row is out of bounds.
let mut num_reads = UInt::new(0);
if dim_vertical > row {
num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
}
for i in range(0u32, num_reads, Comptime::new(false)) {
read_tile_line_without_checks::<F>(
tensor,
position_base,
stride,
tile,
i,
tile_size,
unroll,
);
}
// Pad out-of-bounds lines with zeros.
let zeros = F::vectorized(0., Comptime::get(tile_size));
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
tile[i] = zeros;
}
}
#[cube]
/// Reads a full tile with no bound checks; the caller guarantees every line
/// is in bounds. The loop can be unrolled at compile time.
fn read_without_checks<F: Float>(
tensor: &Tensor<F>,
position_base: UInt,
stride: UInt,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
for i in range(0u32, Comptime::get(tile_size), unroll) {
read_tile_line_without_checks::<F>(
tensor,
position_base,
stride,
tile,
i,
tile_size,
unroll,
);
}
}
#[cube]
/// Reads a full tile with horizontal bound checks on every line; vertical
/// bounds are guaranteed by the caller.
fn read_with_horizontal_checks<F: Float>(
tensor: &Tensor<F>,
col: UInt,
position_base: UInt,
stride: UInt,
dim_horizontal: UInt,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
for i in range(0u32, Comptime::get(tile_size), unroll) {
read_tile_line_with_checks::<F>(
tensor,
col,
position_base,
stride,
dim_horizontal,
tile,
i,
tile_size,
unroll,
);
}
}
#[cube]
/// Reads tile line `i` while guarding against reading past `dim_horizontal`.
/// When the tensor's vectorization matches the tile size, the line is one
/// vectorized load (or zeros if out of bounds); otherwise it is assembled
/// from smaller vectors, reading only as many as fit before the bound.
fn read_tile_line_with_checks<F: Float>(
tensor: &Tensor<F>,
col: UInt,
position_base: UInt,
stride: UInt,
dim_horizontal: UInt,
tile: &mut Array<F>,
i: UInt,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let vectorization_factor = Comptime::vectorization(tensor);
let runtime_vectorization = Comptime::runtime(vectorization_factor);
let position = position_base + i * stride;
if tile_size == vectorization_factor {
if col >= dim_horizontal {
tile[i] = F::vectorized(0., Comptime::get(tile_size));
} else {
// Division converts the element index into a vectorized index.
tile[i] = tensor[position / runtime_vectorization];
}
} else {
let tile_entry = F::vectorized_empty(Comptime::get(tile_size));
// Number of vectorized loads that stay within the horizontal bound.
let mut num_loops = UInt::new(0);
if dim_horizontal > col {
let num_reads = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size));
num_loops = num_reads / runtime_vectorization;
}
for x in range(0u32, num_loops, Comptime::new(false)) {
read_within_vector::<F>(
tensor,
tile_entry,
position,
x,
vectorization_factor,
unroll,
);
}
tile[i] = tile_entry;
}
}
#[cube]
/// Reads tile line `i` with no bound checks. When the tensor's vectorization
/// matches the tile size the line is one vectorized load; otherwise it is
/// assembled from `tile_size / vectorization_factor` smaller vectors.
fn read_tile_line_without_checks<F: Float>(
tensor: &Tensor<F>,
position_base: UInt,
stride: UInt,
tile: &mut Array<F>,
i: UInt,
tile_size: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let vectorization_factor = Comptime::vectorization(tensor);
let runtime_vectorization = Comptime::runtime(vectorization_factor);
let position = position_base + i * stride;
if tile_size == vectorization_factor {
// Division converts the element index into a vectorized index.
tile[i] = tensor[position / runtime_vectorization];
} else {
let tile_entry = F::vectorized_empty(Comptime::get(tile_size));
for j in range(
0u32,
Comptime::get(tile_size / vectorization_factor),
unroll,
) {
read_within_vector::<F>(
tensor,
tile_entry,
position,
j,
vectorization_factor,
unroll,
);
}
tile[i] = tile_entry;
}
}
#[cube]
/// Necessary when vectorization_factor < tile_size.
/// Copies the `i`-th vectorized chunk of the tensor into the matching slots
/// of `tile_entry`; with scalar vectorization a single element is copied.
fn read_within_vector<F: Float>(
tensor: &Tensor<F>,
mut tile_entry: F,
position: UInt,
i: UInt,
vectorization_factor: Comptime<UInt>,
unroll: Comptime<bool>,
) {
let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1);
let runtime_vectorization = Comptime::runtime(vectorization_factor);
if Comptime::get(is_scalar) {
tile_entry[i] = tensor[position + i];
} else {
// Spread the loaded vector's lanes into the wider tile entry.
let intermediate = tensor[position / runtime_vectorization + i];
for j in range(0u32, Comptime::get(vectorization_factor), unroll) {
tile_entry[i * runtime_vectorization + j] = intermediate[j];
}
}
}
#[cfg(feature = "export_tests")]
/// Exported tests for reading tiles in global memory
pub mod tests {
use crate::kernel::matmul::tiling2d_cube::test_utils::{
assert_equals, create_empty, make_config, range_tensor, TILE_SIZE,
};
use crate::JitRuntime;
use super::*;
#[cube(launch)]
#[allow(unused_mut)]
/// Test kernel: reads a whole tile starting at the tensor origin, with or
/// without horizontal bound checks depending on `bound_check_horizontal`.
fn read_whole_test<F: Float>(
tensor: &Tensor<F>,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
bound_check_horizontal: Comptime<bool>,
) {
if Comptime::get(bound_check_horizontal) {
read_with_horizontal_checks::<F>(
tensor,
UInt::new(0),
UInt::new(0),
tensor.stride(0),
tensor.shape(1),
tile,
tile_size,
Comptime::new(true),
);
} else {
read_without_checks::<F>(
tensor,
UInt::new(0),
tensor.stride(0),
tile,
tile_size,
Comptime::new(true),
);
}
}
#[cube(launch)]
#[allow(unused_mut)]
/// Test kernel: reads starting at row 2 / offset 8 so that part of the tile
/// falls out of bounds, exercising the vertical (and optionally horizontal)
/// check paths with zero-padding.
fn read_partial_test<F: Float>(
tensor: &Tensor<F>,
tile: &mut Array<F>,
tile_size: Comptime<UInt>,
bound_check_horizontal: Comptime<bool>,
) {
if Comptime::get(bound_check_horizontal) {
read_with_both_checks::<F>(
tensor,
UInt::new(2),
UInt::new(8),
UInt::new(0),
tensor.stride(0),
tensor.shape(0),
tensor.shape(1),
tile,
tile_size,
Comptime::new(true),
);
} else {
read_with_vertical_checks::<F>(
tensor,
UInt::new(2),
UInt::new(8),
tensor.stride(0),
tensor.shape(0),
tile,
tile_size,
Comptime::new(true),
);
}
}
#[cube(launch)]
/// Test kernel: exercises `read_tile_from_global_memory` at a chosen
/// (unit_row, unit_col) within `lhs`, with bound-check flags taken from the
/// config's m/k dimensions.
fn read_tile_test<F: Float>(
lhs: &Tensor<F>,
tile: &mut Array<F>,
unit_row: UInt,
unit_col: UInt,
config: Comptime<CubeTiling2dConfig>,
) {
let cube_offset = UInt::new(0);
let check_vertical_bounds = Comptime::map(config, |c| c.check_m_bounds);
let check_horizontal_bounds = Comptime::map(config, |c| c.check_k_bounds);
// Strides/shapes are taken from the last two dimensions (row, col).
let lhs_stride = lhs.stride(lhs.rank() - UInt::new(2));
let dim_m = lhs.shape(lhs.rank() - UInt::new(2));
let dim_k = lhs.shape(lhs.rank() - UInt::new(1));
read_tile_from_global_memory::<F>(
lhs,
tile,
cube_offset,
unit_row,
unit_col,
UInt::new(0),
UInt::new(0),
lhs_stride,
dim_m,
dim_k,
check_vertical_bounds,
check_horizontal_bounds,
config,
);
}
/// Exported test
pub fn read_whole_vectorized_like_tile_test<R: JitRuntime>(device: &R::Device) {
let tensor = range_tensor::<R>(4, 4, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
read_whole_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
TILE_SIZE as u8,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
TILE_SIZE.into(),
false,
);
assert_equals::<R>(
tile,
&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
15.0,
],
device,
);
}
/// Exported test
pub fn read_whole_vectorized_less_than_tile_test<R: JitRuntime>(device: &R::Device) {
let vectorization_factor = 2;
let tensor = range_tensor::<R>(4, 4, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
read_whole_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
TILE_SIZE.into(),
false,
);
assert_equals::<R>(
tile,
&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
15.0,
],
device,
);
}
/// Exported test
pub fn read_whole_scalar_test<R: JitRuntime>(device: &R::Device) {
let vectorization_factor = 1;
let tensor = range_tensor::<R>(4, 4, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
read_whole_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
TILE_SIZE.into(),
false,
);
assert_equals::<R>(
tile,
&[
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
15.0,
],
device,
);
}
/// Exported test
pub fn read_whole_scalar_out_of_bound_test<R: JitRuntime>(device: &R::Device) {
let vectorization_factor = 2;
let tensor = range_tensor::<R>(4, 2, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
read_whole_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
vectorization_factor,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
TILE_SIZE.into(),
true,
);
assert_equals::<R>(
tile,
&[
0.0, 1.0, 0.0, 0.0, 2.0, 3.0, 0.0, 0.0, 4.0, 5.0, 0.0, 0.0, 6.0, 7.0, 0.0, 0.0,
],
device,
);
}
/// Exported test
pub fn read_partial_unit_test<R: JitRuntime>(device: &R::Device) {
let tensor = range_tensor::<R>(4, 4, device);
let tile = create_empty::<R>(4, 4, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
read_partial_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
TILE_SIZE as u8,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
TILE_SIZE.into(),
false,
);
let expected = &[
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
assert_equals::<R>(tile, expected, device);
}
/// Exported test
pub fn read_tile_no_checks_unit_test<R: JitRuntime>(device: &R::Device) {
let tensor = range_tensor::<R>(8, 8, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
let config = make_config(8, 8, 8);
read_tile_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
TILE_SIZE as u8,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
0,
0,
config,
);
let expected = &[
0.0, 1.0, 2.0, 3.0, 8.0, 9.0, 10.0, 11.0, 16.0, 17.0, 18.0, 19.0, 24.0, 25.0, 26.0,
27.0,
];
assert_equals::<R>(tile, expected, device);
}
/// Exported test
pub fn read_tile_vertical_checks_unit_test<R: JitRuntime>(device: &R::Device) {
let tensor = range_tensor::<R>(6, 8, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
let config = make_config(6, 8, 8);
read_tile_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
TILE_SIZE as u8,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
4,
0,
config,
);
let expected = &[
32.0, 33.0, 34.0, 35.0, 40.0, 41.0, 42.0, 43.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
assert_equals::<R>(tile, expected, device);
}
/// Exported test
pub fn read_tile_horizontal_checks_unit_test<R: JitRuntime>(device: &R::Device) {
let tensor = range_tensor::<R>(8, 4, device);
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
let config = make_config(8, 4, 8);
read_tile_test_launch::<F32, R>(
tensor.client.clone(),
cube_count,
cube_dim,
TensorArg::vectorized(
TILE_SIZE as u8,
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
),
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
0,
4,
config,
);
let expected = &[
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
];
assert_equals::<R>(tile, expected, device);
}
}

View File

@ -0,0 +1,210 @@
use burn_cube::prelude::*;
use crate::kernel::matmul::config::CubeTiling2dConfig;
/// Writes one tile into shared memory row by row, without transposition.
///
/// Each iteration stores one vectorized tile row; the destination index is
/// divided by `tile_size` because the shared memory is indexed in vectorized
/// units (the test harness below builds it with `SharedMemory::vectorized`).
/// When `check_sm_bounds` is set, the entire tile write is skipped once
/// `write_row` reaches `block_size_k` (the vertical extent of the shared
/// memory).
#[cube]
pub(crate) fn write_tile_plain<F: Float>(
    tile: &Array<F>,
    mut shared_memory: SharedMemory<F>,
    write_row: UInt,
    write_col: UInt,
    sm_stride: UInt,
    config: Comptime<CubeTiling2dConfig>,
) {
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let unroll = Comptime::map(config, |c| c.unroll_tile);
    let check_sm_bounds = Comptime::map(config, |c| c.check_sm_bounds);
    let tile_size_runtime = Comptime::runtime(tile_size);
    // Unvectorized index of the tile's top-left corner in shared memory.
    let sm_position_base = write_row * sm_stride + write_col;
    if Comptime::get(check_sm_bounds) {
        let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
        // Guard is on write_row only; a tile is written whole or not at all.
        if write_row < sm_dim_vertical {
            for i in range(0u32, Comptime::get(tile_size), unroll) {
                shared_memory[(sm_position_base + i * sm_stride) / tile_size_runtime] = tile[i];
            }
        }
    } else {
        for i in range(0u32, Comptime::get(tile_size), unroll) {
            shared_memory[(sm_position_base + i * sm_stride) / tile_size_runtime] = tile[i];
        }
    }
}
/// Writes one tile into shared memory, transposing it on the way.
///
/// When `tile_size == 1` the transpose of a 1x1 tile is the tile itself, so
/// the single value is stored directly; otherwise the work is delegated to
/// `transpose_tile_to_shared_memory`. As in `write_tile_plain`, when
/// `check_sm_bounds` is set the whole write is skipped once `write_row`
/// reaches `block_size_k`.
#[cube]
pub(crate) fn write_tile_transposed<F: Float>(
    tile: &Array<F>,
    mut shared_memory: SharedMemory<F>,
    write_row: UInt,
    write_col: UInt,
    sm_stride: UInt,
    config: Comptime<CubeTiling2dConfig>,
) {
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let check_sm_bounds = Comptime::map(config, |c| c.check_sm_bounds);
    // Comptime branch: tile_size == 1 means the tile is a single scalar.
    let is_scalar = Comptime::map(tile_size, |c| c.val == 1);
    // Unvectorized index of the tile's top-left corner in shared memory.
    let sm_position_base = write_row * sm_stride + write_col;
    if Comptime::get(is_scalar) {
        if Comptime::get(check_sm_bounds) {
            let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
            if write_row < sm_dim_vertical {
                shared_memory[sm_position_base] = tile[0];
            }
        } else {
            shared_memory[sm_position_base] = tile[0];
        }
    } else if Comptime::get(check_sm_bounds) {
        let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
        if write_row < sm_dim_vertical {
            transpose_tile_to_shared_memory::<F>(
                tile,
                shared_memory,
                sm_position_base,
                sm_stride,
                config,
            );
        }
    } else {
        transpose_tile_to_shared_memory::<F>(
            tile,
            shared_memory,
            sm_position_base,
            sm_stride,
            config,
        );
    }
}
/// Stores the transpose of `tile` into shared memory.
///
/// For each output row `i`, gathers element `i` of every tile row into a
/// fresh vectorized value (i.e. column `i` of the tile) and writes it as one
/// shared-memory element. The destination index is divided by `tile_size`
/// because the shared memory is indexed in vectorized units.
#[cube]
fn transpose_tile_to_shared_memory<F: Float>(
    tile: &Array<F>,
    mut shared_memory: SharedMemory<F>,
    sm_position_base: UInt,
    sm_stride: UInt,
    config: Comptime<CubeTiling2dConfig>,
) {
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let unroll = Comptime::map(config, |c| c.unroll_tile);
    for i in range(0u32, Comptime::get(tile_size), unroll) {
        // Gather column i of the tile into one vectorized value.
        let mut transposed = F::vectorized_empty(Comptime::get(tile_size));
        // Unrolling this one makes the difference
        for j in range(0u32, Comptime::get(tile_size), unroll) {
            transposed[j] = tile[j][i];
        }
        let sm_position = (sm_position_base + i * sm_stride) / Comptime::runtime(tile_size);
        shared_memory[sm_position] = transposed;
    }
}
#[cfg(feature = "export_tests")]
/// Exported tests for writing tiles to shared memory
pub mod tests {
    use crate::kernel::matmul::tiling2d_cube::test_utils::{
        assert_equals, create_empty, make_config, range_tensor, TILE_SIZE,
    };
    use crate::JitRuntime;
    use super::*;
    // Kernel harness: writes `tile` at position (0, 0) of a fresh vectorized
    // shared memory — plain or transposed depending on the comptime flag —
    // then copies the shared memory out to `sm_out` so the host can inspect
    // the result.
    #[cube(launch)]
    fn write_tile_test<F: Float>(
        tile: &Array<F>,
        sm_out: &mut Array<F>,
        config: Comptime<CubeTiling2dConfig>,
        transposed: Comptime<bool>,
    ) {
        let tile_size = Comptime::map(config, |c| c.tile_size);
        let block_size_m = Comptime::map(config, |c| c.block_size_m);
        let block_size_k = Comptime::map(config, |c| c.block_size_k);
        // One shared-memory row spans block_size_m scalar elements.
        let sm_stride = block_size_m;
        let sm_size = Comptime::runtime(block_size_k * block_size_m);
        let shared_memory = SharedMemory::<F>::vectorized(sm_size, Comptime::get(tile_size));
        if Comptime::get(transposed) {
            write_tile_transposed(
                tile,
                shared_memory,
                UInt::new(0),
                UInt::new(0),
                Comptime::runtime(sm_stride),
                config,
            );
        } else {
            write_tile_plain(
                tile,
                shared_memory,
                UInt::new(0),
                UInt::new(0),
                Comptime::runtime(sm_stride),
                config,
            );
        }
        // Copy the shared memory out for host-side assertions (not unrolled).
        for i in range(0u32, sm_size, Comptime::new(false)) {
            sm_out[i] = shared_memory[i];
        }
    }
    /// Exported test
    // A 4x4 tile written plainly into an 8x8 shared memory: the tile's rows
    // land at the start of the first four shared-memory rows, rest is zero.
    pub fn write_tile_plain_unit_test<R: JitRuntime>(device: &R::Device) {
        let tile = range_tensor::<R>(4, 4, device);
        let sm_out = create_empty::<R>(8, 8, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);
        let config = make_config(8, 8, 8);
        write_tile_test_launch::<F32, R>(
            tile.client.clone(),
            cube_count,
            cube_dim,
            ArrayArg::vectorized(TILE_SIZE as u8, &tile.handle, 4),
            // NOTE(review): the transposed test below passes 64 as the array
            // length here while this one passes 4 — confirm which value
            // `ArrayArg::vectorized` expects for an 8x8 buffer.
            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 4),
            config,
            false,
        );
        let expected = &[
            0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0, 8.0,
            9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];
        assert_equals::<R>(sm_out, expected, device);
    }
    /// Exported test
    // Same setup as the plain test, but transposed: columns of the 4x4 tile
    // become the first four shared-memory rows.
    pub fn write_tile_transposed_unit_test<R: JitRuntime>(device: &R::Device) {
        let tile = range_tensor::<R>(4, 4, device);
        let sm_out = create_empty::<R>(8, 8, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);
        let config = make_config(8, 8, 8);
        write_tile_test_launch::<F32, R>(
            tile.client.clone(),
            cube_count,
            cube_dim,
            ArrayArg::vectorized(TILE_SIZE as u8, &tile.handle, 4),
            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
            config,
            true,
        );
        let expected = &[
            0.0, 4.0, 8.0, 12.0, 0.0, 0.0, 0.0, 0.0, 1.0, 5.0, 9.0, 13.0, 0.0, 0.0, 0.0, 0.0, 2.0,
            6.0, 10.0, 14.0, 0.0, 0.0, 0.0, 0.0, 3.0, 7.0, 11.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];
        assert_equals::<R>(sm_out, expected, device);
    }
}

View File

@ -2,7 +2,8 @@
mod tests {
use super::*;
use burn_jit::kernel::matmul::tiling2d_cube::{
compute_loop_tests, load_shared_memory_tests, outer_product_tests, write_output_tests,
compute_loop_tests, load_shared_memory_tests, outer_product_tests, tile_read_tests,
tile_write_tests, write_output_tests,
};
use burn_jit::kernel::matmul::{matmul, MatmulStrategy, Tiling2dConfig};
use burn_tensor::{Shape, Tensor};
@ -38,62 +39,54 @@ mod tests {
#[test]
pub fn tiling2d_matmul_read_whole_vectorized_like_tile_test() {
load_shared_memory_tests::read_whole_vectorized_like_tile_test::<TestRuntime>(
&Default::default(),
)
tile_read_tests::read_whole_vectorized_like_tile_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn tiling2d_matmul_read_whole_vectorized_less_than_tile_test() {
load_shared_memory_tests::read_whole_vectorized_less_than_tile_test::<TestRuntime>(
tile_read_tests::read_whole_vectorized_less_than_tile_test::<TestRuntime>(
&Default::default(),
)
}
#[test]
pub fn tiling2d_matmul_read_whole_scalar_test() {
load_shared_memory_tests::read_whole_scalar_test::<TestRuntime>(&Default::default())
tile_read_tests::read_whole_scalar_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn read_whole_scalar_out_of_bound_test() {
load_shared_memory_tests::read_whole_scalar_out_of_bound_test::<TestRuntime>(
&Default::default(),
)
tile_read_tests::read_whole_scalar_out_of_bound_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn tiling2d_matmul_read_partial_unit_test() {
load_shared_memory_tests::read_partial_unit_test::<TestRuntime>(&Default::default())
tile_read_tests::read_partial_unit_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn tiling2d_matmul_load_tile_no_checks_unit_test() {
load_shared_memory_tests::load_tile_no_checks_unit_test::<TestRuntime>(&Default::default())
tile_read_tests::read_tile_no_checks_unit_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn tiling2d_matmul_load_tile_vertical_checks_unit_test() {
load_shared_memory_tests::load_tile_vertical_checks_unit_test::<TestRuntime>(
&Default::default(),
)
tile_read_tests::read_tile_vertical_checks_unit_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn tiling2d_matmul_load_tile_horizontal_checks_unit_test() {
load_shared_memory_tests::load_tile_horizontal_checks_unit_test::<TestRuntime>(
&Default::default(),
)
tile_read_tests::read_tile_horizontal_checks_unit_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn write_tile_plain_unit_test() {
load_shared_memory_tests::write_tile_plain_unit_test::<TestRuntime>(&Default::default())
tile_write_tests::write_tile_plain_unit_test::<TestRuntime>(&Default::default())
}
#[test]
pub fn write_tile_transposed_unit_test() {
load_shared_memory_tests::write_tile_transposed_unit_test::<TestRuntime>(&Default::default())
tile_write_tests::write_tile_transposed_unit_test::<TestRuntime>(&Default::default())
}
#[test]