mirror of https://github.com/tracel-ai/burn.git
refactor load shared memory file
This commit is contained in:
parent
74b06fbf55
commit
cbb4de7156
|
@ -87,6 +87,7 @@ impl CubeTiling2dConfig {
|
|||
&& config.block_size_n % config.tile_size == 0,
|
||||
"Tiling 2d algorithm assumes tile size divides block size perfectly. "
|
||||
);
|
||||
|
||||
CubeTiling2dConfig {
|
||||
block_size_m: UInt::new(config.block_size_m as u32),
|
||||
block_size_k: UInt::new(config.block_size_k as u32),
|
||||
|
|
|
@ -51,13 +51,13 @@ pub(crate) struct CubeTiling2dInfo {
|
|||
pub out_stride: UInt,
|
||||
}
|
||||
|
||||
#[derive(CubeType)]
|
||||
#[derive(CubeType, Copy, Clone)]
|
||||
pub(crate) struct SharedMemories<F: Float> {
|
||||
pub lhs: SharedMemory<F>,
|
||||
pub rhs: SharedMemory<F>,
|
||||
}
|
||||
|
||||
#[derive(CubeType)]
|
||||
#[derive(CubeType, Copy, Clone)]
|
||||
/// Number of elements in previous batches
|
||||
/// Not divided by vectorization facto
|
||||
pub(crate) struct BatchOffsets {
|
||||
|
@ -217,14 +217,38 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
|
|||
.next()
|
||||
.unwrap_or(1)
|
||||
};
|
||||
let lhs_vectorization = match lhs_transposed {
|
||||
true => vectorization(m),
|
||||
false => vectorization(k),
|
||||
};
|
||||
let rhs_vectorization = match rhs_transposed {
|
||||
true => vectorization(k),
|
||||
false => vectorization(n),
|
||||
};
|
||||
let out_vectorization = vectorization(n);
|
||||
|
||||
tiling2d_cube_launch::<E::FloatPrimitive, R>(
|
||||
client,
|
||||
tiling2d_cube_count(&out.shape, &config),
|
||||
tiling2d_cube_dim(&config),
|
||||
TensorArg::vectorized(vectorization(k), &lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
TensorArg::vectorized(vectorization(n), &rhs.handle, &rhs.strides, &rhs.shape.dims),
|
||||
TensorArg::vectorized(vectorization(n), &out.handle, &out.strides, &out.shape.dims),
|
||||
TensorArg::vectorized(
|
||||
lhs_vectorization,
|
||||
&lhs.handle,
|
||||
&lhs.strides,
|
||||
&lhs.shape.dims,
|
||||
),
|
||||
TensorArg::vectorized(
|
||||
rhs_vectorization,
|
||||
&rhs.handle,
|
||||
&rhs.strides,
|
||||
&rhs.shape.dims,
|
||||
),
|
||||
TensorArg::vectorized(
|
||||
out_vectorization,
|
||||
&out.handle,
|
||||
&out.strides,
|
||||
&out.shape.dims,
|
||||
),
|
||||
CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed),
|
||||
);
|
||||
|
||||
|
|
|
@ -5,14 +5,11 @@ use crate::kernel::matmul::config::CubeTiling2dConfig;
|
|||
use super::{
|
||||
base::{BatchOffsets, Coordinates, CubeTiling2dInfo, SharedMemories},
|
||||
compute_loop::{compute_loop, compute_loop_expand},
|
||||
load_shared_memory::{
|
||||
load_lhs_transposed, load_lhs_transposed_expand, load_rhs_plain, load_rhs_plain_expand,
|
||||
},
|
||||
load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand},
|
||||
write_output::{write_to_output, write_to_output_expand},
|
||||
};
|
||||
|
||||
#[cube]
|
||||
#[allow(unused_mut)]
|
||||
pub(crate) fn block_loop<F: Float>(
|
||||
lhs: &Tensor<F>,
|
||||
rhs: &Tensor<F>,
|
||||
|
@ -31,8 +28,7 @@ pub(crate) fn block_loop<F: Float>(
|
|||
for k in range(0u32, n_loops, Comptime::new(false)) {
|
||||
let k = k * Comptime::runtime(block_size_k);
|
||||
|
||||
load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
|
||||
load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
|
||||
load_to_shared_memories(lhs, rhs, coordinates, k, offsets, shared, config, info);
|
||||
|
||||
sync_units();
|
||||
|
||||
|
|
|
@ -2,7 +2,43 @@ use burn_cube::prelude::*;
|
|||
|
||||
use crate::kernel::matmul::config::CubeTiling2dConfig;
|
||||
|
||||
use super::base::{Coordinates, CubeTiling2dInfo};
|
||||
use super::{
|
||||
base::{BatchOffsets, Coordinates, CubeTiling2dInfo, SharedMemories},
|
||||
tile_read::{read_tile_from_global_memory, read_tile_from_global_memory_expand},
|
||||
tile_write::{
|
||||
write_tile_plain, write_tile_plain_expand, write_tile_transposed,
|
||||
write_tile_transposed_expand,
|
||||
},
|
||||
};
|
||||
|
||||
#[cube]
|
||||
pub(crate) fn load_to_shared_memories<F: Float>(
|
||||
lhs: &Tensor<F>,
|
||||
rhs: &Tensor<F>,
|
||||
coordinates: Coordinates,
|
||||
k: UInt,
|
||||
offsets: BatchOffsets,
|
||||
shared: SharedMemories<F>,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
info: CubeTiling2dInfo,
|
||||
) {
|
||||
let lhs_transposed = Comptime::map(config, |c| c.lhs_transposed);
|
||||
let rhs_transposed = Comptime::map(config, |c| c.rhs_transposed);
|
||||
|
||||
// Lhs must be loaded as transposed. If it already is in global memory, we load as plain.
|
||||
if Comptime::get(lhs_transposed) {
|
||||
// load_lhs_plain::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
|
||||
} else {
|
||||
load_lhs_transposed::<F>(lhs, coordinates, k, offsets.lhs, shared.lhs, config, info);
|
||||
}
|
||||
|
||||
// Rhs must be loaded as plain. If it is transposed in global memory, we transpose it back.
|
||||
if Comptime::get(rhs_transposed) {
|
||||
// load_rhs_transposed::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
|
||||
} else {
|
||||
load_rhs_plain::<F>(rhs, coordinates, k, offsets.rhs, shared.rhs, config, info);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
pub(crate) fn load_lhs_transposed<F: Float>(
|
||||
|
@ -24,7 +60,7 @@ pub(crate) fn load_lhs_transposed<F: Float>(
|
|||
|
||||
let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
|
||||
|
||||
load_tile::<F>(
|
||||
read_tile_from_global_memory::<F>(
|
||||
lhs,
|
||||
&mut tile,
|
||||
offset,
|
||||
|
@ -69,7 +105,7 @@ pub(crate) fn load_rhs_plain<F: Float>(
|
|||
|
||||
let mut tile = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
|
||||
|
||||
load_tile::<F>(
|
||||
read_tile_from_global_memory::<F>(
|
||||
rhs,
|
||||
&mut tile,
|
||||
offset,
|
||||
|
@ -95,414 +131,6 @@ pub(crate) fn load_rhs_plain<F: Float>(
|
|||
);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn load_tile<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
tile: &mut Array<F>,
|
||||
cube_offset: UInt,
|
||||
load_row: UInt,
|
||||
load_col: UInt,
|
||||
skip_row: UInt,
|
||||
skip_col: UInt,
|
||||
tensor_stride: UInt,
|
||||
dim_vertical: UInt,
|
||||
dim_horizontal: UInt,
|
||||
check_vertical_bounds: Comptime<bool>,
|
||||
check_horizontal_bounds: Comptime<bool>,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
) {
|
||||
let tile_size = Comptime::map(config, |c| c.tile_size);
|
||||
let unroll = Comptime::map(config, |c| c.unroll_tile);
|
||||
|
||||
let tensor_position_base = load_row * tensor_stride + load_col + cube_offset;
|
||||
|
||||
if Comptime::get(check_vertical_bounds) {
|
||||
let row = skip_row + load_row;
|
||||
|
||||
if Comptime::get(check_horizontal_bounds) {
|
||||
let col = skip_col + load_col;
|
||||
read_with_both_checks::<F>(
|
||||
tensor,
|
||||
row,
|
||||
col,
|
||||
tensor_position_base,
|
||||
tensor_stride,
|
||||
dim_vertical,
|
||||
dim_horizontal,
|
||||
tile,
|
||||
tile_size,
|
||||
unroll,
|
||||
);
|
||||
} else {
|
||||
read_with_vertical_checks::<F>(
|
||||
tensor,
|
||||
row,
|
||||
tensor_position_base,
|
||||
tensor_stride,
|
||||
dim_vertical,
|
||||
tile,
|
||||
tile_size,
|
||||
unroll,
|
||||
);
|
||||
}
|
||||
} else if Comptime::get(check_horizontal_bounds) {
|
||||
let col = skip_col + load_col;
|
||||
read_with_horizontal_checks::<F>(
|
||||
tensor,
|
||||
col,
|
||||
tensor_position_base,
|
||||
tensor_stride,
|
||||
dim_horizontal,
|
||||
tile,
|
||||
tile_size,
|
||||
unroll,
|
||||
);
|
||||
} else {
|
||||
read_without_checks::<F>(
|
||||
tensor,
|
||||
tensor_position_base,
|
||||
tensor_stride,
|
||||
tile,
|
||||
tile_size,
|
||||
unroll,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn write_tile_plain<F: Float>(
|
||||
tile: &Array<F>,
|
||||
mut shared_memory: SharedMemory<F>,
|
||||
write_row: UInt,
|
||||
write_col: UInt,
|
||||
sm_stride: UInt,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
) {
|
||||
let tile_size = Comptime::map(config, |c| c.tile_size);
|
||||
let unroll = Comptime::map(config, |c| c.unroll_tile);
|
||||
let check_sm_bounds = Comptime::map(config, |c| c.check_sm_bounds);
|
||||
let tile_size_runtime = Comptime::runtime(tile_size);
|
||||
|
||||
let sm_position_base = write_row * sm_stride + write_col;
|
||||
|
||||
if Comptime::get(check_sm_bounds) {
|
||||
let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
|
||||
if write_row < sm_dim_vertical {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
shared_memory[(sm_position_base + i * sm_stride) / tile_size_runtime] = tile[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
shared_memory[(sm_position_base + i * sm_stride) / tile_size_runtime] = tile[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn write_tile_transposed<F: Float>(
|
||||
tile: &Array<F>,
|
||||
mut shared_memory: SharedMemory<F>,
|
||||
write_row: UInt,
|
||||
write_col: UInt,
|
||||
sm_stride: UInt,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
) {
|
||||
let tile_size = Comptime::map(config, |c| c.tile_size);
|
||||
let check_sm_bounds = Comptime::map(config, |c| c.check_sm_bounds);
|
||||
let is_scalar = Comptime::map(tile_size, |c| c.val == 1);
|
||||
|
||||
let sm_position_base = write_row * sm_stride + write_col;
|
||||
|
||||
if Comptime::get(is_scalar) {
|
||||
if Comptime::get(check_sm_bounds) {
|
||||
let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
|
||||
if write_row < sm_dim_vertical {
|
||||
shared_memory[sm_position_base] = tile[0];
|
||||
}
|
||||
} else {
|
||||
shared_memory[sm_position_base] = tile[0];
|
||||
}
|
||||
} else if Comptime::get(check_sm_bounds) {
|
||||
let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
|
||||
if write_row < sm_dim_vertical {
|
||||
transpose_tile_to_shared_memory::<F>(
|
||||
tile,
|
||||
shared_memory,
|
||||
sm_position_base,
|
||||
sm_stride,
|
||||
config,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
transpose_tile_to_shared_memory::<F>(
|
||||
tile,
|
||||
shared_memory,
|
||||
sm_position_base,
|
||||
sm_stride,
|
||||
config,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn transpose_tile_to_shared_memory<F: Float>(
|
||||
tile: &Array<F>,
|
||||
mut shared_memory: SharedMemory<F>,
|
||||
sm_position_base: UInt,
|
||||
sm_stride: UInt,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
) {
|
||||
let tile_size = Comptime::map(config, |c| c.tile_size);
|
||||
let unroll = Comptime::map(config, |c| c.unroll_tile);
|
||||
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let mut transposed = F::vectorized_empty(Comptime::get(tile_size));
|
||||
|
||||
// Unrolling this one makes the difference
|
||||
for j in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
transposed[j] = tile[j][i];
|
||||
}
|
||||
|
||||
let sm_position = (sm_position_base + i * sm_stride) / Comptime::runtime(tile_size);
|
||||
shared_memory[sm_position] = transposed;
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn read_with_both_checks<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
row: UInt,
|
||||
col: UInt,
|
||||
position_base: UInt,
|
||||
stride: UInt,
|
||||
dim_vertical: UInt,
|
||||
dim_horizontal: UInt,
|
||||
tile: &mut Array<F>,
|
||||
tile_size: Comptime<UInt>,
|
||||
unroll: Comptime<bool>,
|
||||
) {
|
||||
let tile_size_runtime = Comptime::runtime(tile_size);
|
||||
|
||||
let mut num_reads = UInt::new(0);
|
||||
if dim_vertical > row {
|
||||
num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
|
||||
}
|
||||
|
||||
for i in range(0u32, num_reads, Comptime::new(false)) {
|
||||
read_tile_line_with_checks::<F>(
|
||||
tensor,
|
||||
col,
|
||||
position_base,
|
||||
stride,
|
||||
dim_horizontal,
|
||||
tile,
|
||||
i,
|
||||
tile_size,
|
||||
unroll,
|
||||
);
|
||||
}
|
||||
|
||||
let zeros = F::vectorized(0., Comptime::get(tile_size));
|
||||
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
|
||||
tile[i] = zeros;
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn read_with_vertical_checks<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
row: UInt,
|
||||
position_base: UInt,
|
||||
stride: UInt,
|
||||
dim_vertical: UInt,
|
||||
tile: &mut Array<F>,
|
||||
tile_size: Comptime<UInt>,
|
||||
unroll: Comptime<bool>,
|
||||
) {
|
||||
let tile_size_runtime = Comptime::runtime(tile_size);
|
||||
|
||||
let mut num_reads = UInt::new(0);
|
||||
if dim_vertical > row {
|
||||
num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
|
||||
}
|
||||
|
||||
for i in range(0u32, num_reads, Comptime::new(false)) {
|
||||
read_tile_line_without_checks::<F>(
|
||||
tensor,
|
||||
position_base,
|
||||
stride,
|
||||
tile,
|
||||
i,
|
||||
tile_size,
|
||||
unroll,
|
||||
);
|
||||
}
|
||||
|
||||
let zeros = F::vectorized(0., Comptime::get(tile_size));
|
||||
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
|
||||
tile[i] = zeros;
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn read_without_checks<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
position_base: UInt,
|
||||
stride: UInt,
|
||||
tile: &mut Array<F>,
|
||||
tile_size: Comptime<UInt>,
|
||||
unroll: Comptime<bool>,
|
||||
) {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
read_tile_line_without_checks::<F>(
|
||||
tensor,
|
||||
position_base,
|
||||
stride,
|
||||
tile,
|
||||
i,
|
||||
tile_size,
|
||||
unroll,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn read_with_horizontal_checks<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
col: UInt,
|
||||
position_base: UInt,
|
||||
stride: UInt,
|
||||
dim_horizontal: UInt,
|
||||
tile: &mut Array<F>,
|
||||
tile_size: Comptime<UInt>,
|
||||
unroll: Comptime<bool>,
|
||||
) {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
read_tile_line_with_checks::<F>(
|
||||
tensor,
|
||||
col,
|
||||
position_base,
|
||||
stride,
|
||||
dim_horizontal,
|
||||
tile,
|
||||
i,
|
||||
tile_size,
|
||||
unroll,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn read_tile_line_with_checks<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
col: UInt,
|
||||
position_base: UInt,
|
||||
stride: UInt,
|
||||
dim_horizontal: UInt,
|
||||
tile: &mut Array<F>,
|
||||
i: UInt,
|
||||
tile_size: Comptime<UInt>,
|
||||
unroll: Comptime<bool>,
|
||||
) {
|
||||
let vectorization_factor = Comptime::vectorization(tensor);
|
||||
let runtime_vectorization = Comptime::runtime(vectorization_factor);
|
||||
|
||||
let position = position_base + i * stride;
|
||||
|
||||
if tile_size == vectorization_factor {
|
||||
if col >= dim_horizontal {
|
||||
tile[i] = F::vectorized(0., Comptime::get(tile_size));
|
||||
} else {
|
||||
tile[i] = tensor[position / runtime_vectorization];
|
||||
}
|
||||
} else {
|
||||
let tile_entry = F::vectorized_empty(Comptime::get(tile_size));
|
||||
|
||||
let mut num_loops = UInt::new(0);
|
||||
if dim_horizontal > col {
|
||||
let num_reads = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size));
|
||||
num_loops = num_reads / runtime_vectorization;
|
||||
}
|
||||
|
||||
for x in range(0u32, num_loops, Comptime::new(false)) {
|
||||
read_within_vector::<F>(
|
||||
tensor,
|
||||
tile_entry,
|
||||
position,
|
||||
x,
|
||||
vectorization_factor,
|
||||
unroll,
|
||||
);
|
||||
}
|
||||
|
||||
tile[i] = tile_entry;
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn read_tile_line_without_checks<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
position_base: UInt,
|
||||
stride: UInt,
|
||||
tile: &mut Array<F>,
|
||||
i: UInt,
|
||||
tile_size: Comptime<UInt>,
|
||||
unroll: Comptime<bool>,
|
||||
) {
|
||||
let vectorization_factor = Comptime::vectorization(tensor);
|
||||
let runtime_vectorization = Comptime::runtime(vectorization_factor);
|
||||
|
||||
let position = position_base + i * stride;
|
||||
|
||||
if tile_size == vectorization_factor {
|
||||
tile[i] = tensor[position / runtime_vectorization];
|
||||
} else {
|
||||
let tile_entry = F::vectorized_empty(Comptime::get(tile_size));
|
||||
|
||||
for j in range(
|
||||
0u32,
|
||||
Comptime::get(tile_size / vectorization_factor),
|
||||
unroll,
|
||||
) {
|
||||
read_within_vector::<F>(
|
||||
tensor,
|
||||
tile_entry,
|
||||
position,
|
||||
j,
|
||||
vectorization_factor,
|
||||
unroll,
|
||||
);
|
||||
}
|
||||
|
||||
tile[i] = tile_entry;
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
/// Necessary when vectorization_factor < tile_size
|
||||
fn read_within_vector<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
mut tile_entry: F,
|
||||
position: UInt,
|
||||
i: UInt,
|
||||
vectorization_factor: Comptime<UInt>,
|
||||
unroll: Comptime<bool>,
|
||||
) {
|
||||
let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1);
|
||||
let runtime_vectorization = Comptime::runtime(vectorization_factor);
|
||||
|
||||
if Comptime::get(is_scalar) {
|
||||
tile_entry[i] = tensor[position + i];
|
||||
} else {
|
||||
let intermediate = tensor[position / runtime_vectorization + i];
|
||||
|
||||
for j in range(0u32, Comptime::get(vectorization_factor), unroll) {
|
||||
tile_entry[i * runtime_vectorization + j] = intermediate[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "export_tests")]
|
||||
/// Exported tests for loading to shared memory
|
||||
pub mod tests {
|
||||
|
@ -513,145 +141,6 @@ pub mod tests {
|
|||
|
||||
use super::{super::base::CoordinatesExpand, super::base::CubeTiling2dInfoExpand, *};
|
||||
|
||||
#[cube(launch)]
|
||||
#[allow(unused_mut)]
|
||||
fn read_whole_test<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
tile: &mut Array<F>,
|
||||
tile_size: Comptime<UInt>,
|
||||
bound_check_horizontal: Comptime<bool>,
|
||||
) {
|
||||
if Comptime::get(bound_check_horizontal) {
|
||||
read_with_horizontal_checks::<F>(
|
||||
tensor,
|
||||
UInt::new(0),
|
||||
UInt::new(0),
|
||||
tensor.stride(0),
|
||||
tensor.shape(1),
|
||||
tile,
|
||||
tile_size,
|
||||
Comptime::new(true),
|
||||
);
|
||||
} else {
|
||||
read_without_checks::<F>(
|
||||
tensor,
|
||||
UInt::new(0),
|
||||
tensor.stride(0),
|
||||
tile,
|
||||
tile_size,
|
||||
Comptime::new(true),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
#[allow(unused_mut)]
|
||||
fn read_partial_test<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
tile: &mut Array<F>,
|
||||
tile_size: Comptime<UInt>,
|
||||
bound_check_horizontal: Comptime<bool>,
|
||||
) {
|
||||
if Comptime::get(bound_check_horizontal) {
|
||||
read_with_both_checks::<F>(
|
||||
tensor,
|
||||
UInt::new(2),
|
||||
UInt::new(8),
|
||||
UInt::new(0),
|
||||
tensor.stride(0),
|
||||
tensor.shape(0),
|
||||
tensor.shape(1),
|
||||
tile,
|
||||
tile_size,
|
||||
Comptime::new(true),
|
||||
);
|
||||
} else {
|
||||
read_with_vertical_checks::<F>(
|
||||
tensor,
|
||||
UInt::new(2),
|
||||
UInt::new(8),
|
||||
tensor.stride(0),
|
||||
tensor.shape(0),
|
||||
tile,
|
||||
tile_size,
|
||||
Comptime::new(true),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
#[allow(unused_mut)]
|
||||
fn load_tile_test<F: Float>(
|
||||
lhs: &Tensor<F>,
|
||||
tile: &mut Array<F>,
|
||||
unit_row: UInt,
|
||||
unit_col: UInt,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
) {
|
||||
let cube_offset = UInt::new(0);
|
||||
let check_vertical_bounds = Comptime::map(config, |c| c.check_m_bounds);
|
||||
let check_horizontal_bounds = Comptime::map(config, |c| c.check_k_bounds);
|
||||
let lhs_stride = lhs.stride(lhs.rank() - UInt::new(2));
|
||||
let dim_m = lhs.shape(lhs.rank() - UInt::new(2));
|
||||
let dim_k = lhs.shape(lhs.rank() - UInt::new(1));
|
||||
|
||||
load_tile::<F>(
|
||||
lhs,
|
||||
tile,
|
||||
cube_offset,
|
||||
unit_row,
|
||||
unit_col,
|
||||
UInt::new(0),
|
||||
UInt::new(0),
|
||||
lhs_stride,
|
||||
dim_m,
|
||||
dim_k,
|
||||
check_vertical_bounds,
|
||||
check_horizontal_bounds,
|
||||
config,
|
||||
);
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
fn write_tile_test<F: Float>(
|
||||
tile: &Array<F>,
|
||||
sm_out: &mut Array<F>,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
transposed: Comptime<bool>,
|
||||
) {
|
||||
let tile_size = Comptime::map(config, |c| c.tile_size);
|
||||
let block_size_m = Comptime::map(config, |c| c.block_size_m);
|
||||
let block_size_k = Comptime::map(config, |c| c.block_size_k);
|
||||
|
||||
let sm_stride = block_size_m;
|
||||
let sm_size = Comptime::runtime(block_size_k * block_size_m);
|
||||
let shared_memory = SharedMemory::<F>::vectorized(sm_size, Comptime::get(tile_size));
|
||||
|
||||
if Comptime::get(transposed) {
|
||||
write_tile_transposed(
|
||||
tile,
|
||||
shared_memory,
|
||||
UInt::new(0),
|
||||
UInt::new(0),
|
||||
Comptime::runtime(sm_stride),
|
||||
config,
|
||||
);
|
||||
} else {
|
||||
write_tile_plain(
|
||||
tile,
|
||||
shared_memory,
|
||||
UInt::new(0),
|
||||
UInt::new(0),
|
||||
Comptime::runtime(sm_stride),
|
||||
config,
|
||||
);
|
||||
}
|
||||
|
||||
for i in range(0u32, sm_size, Comptime::new(false)) {
|
||||
sm_out[i] = shared_memory[i];
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
fn load_tensor_test<F: Float>(
|
||||
tensor: &Tensor<F>,
|
||||
|
@ -762,314 +251,6 @@ pub mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn read_whole_vectorized_like_tile_test<R: JitRuntime>(device: &R::Device) {
|
||||
let tensor = range_tensor::<R>(4, 4, device);
|
||||
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
||||
read_whole_test_launch::<F32, R>(
|
||||
tensor.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::vectorized(
|
||||
TILE_SIZE as u8,
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
),
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
|
||||
TILE_SIZE.into(),
|
||||
false,
|
||||
);
|
||||
|
||||
assert_equals::<R>(
|
||||
tile,
|
||||
&[
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
|
||||
15.0,
|
||||
],
|
||||
device,
|
||||
);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn read_whole_vectorized_less_than_tile_test<R: JitRuntime>(device: &R::Device) {
|
||||
let vectorization_factor = 2;
|
||||
let tensor = range_tensor::<R>(4, 4, device);
|
||||
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
||||
read_whole_test_launch::<F32, R>(
|
||||
tensor.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::vectorized(
|
||||
vectorization_factor,
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
),
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
|
||||
TILE_SIZE.into(),
|
||||
false,
|
||||
);
|
||||
|
||||
assert_equals::<R>(
|
||||
tile,
|
||||
&[
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
|
||||
15.0,
|
||||
],
|
||||
device,
|
||||
);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn read_whole_scalar_test<R: JitRuntime>(device: &R::Device) {
|
||||
let vectorization_factor = 1;
|
||||
let tensor = range_tensor::<R>(4, 4, device);
|
||||
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
||||
read_whole_test_launch::<F32, R>(
|
||||
tensor.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::vectorized(
|
||||
vectorization_factor,
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
),
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
|
||||
TILE_SIZE.into(),
|
||||
false,
|
||||
);
|
||||
|
||||
assert_equals::<R>(
|
||||
tile,
|
||||
&[
|
||||
0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
|
||||
15.0,
|
||||
],
|
||||
device,
|
||||
);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn read_whole_scalar_out_of_bound_test<R: JitRuntime>(device: &R::Device) {
|
||||
let vectorization_factor = 2;
|
||||
let tensor = range_tensor::<R>(4, 2, device);
|
||||
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
||||
read_whole_test_launch::<F32, R>(
|
||||
tensor.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::vectorized(
|
||||
vectorization_factor,
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
),
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
|
||||
TILE_SIZE.into(),
|
||||
true,
|
||||
);
|
||||
|
||||
assert_equals::<R>(
|
||||
tile,
|
||||
&[
|
||||
0.0, 1.0, 0.0, 0.0, 2.0, 3.0, 0.0, 0.0, 4.0, 5.0, 0.0, 0.0, 6.0, 7.0, 0.0, 0.0,
|
||||
],
|
||||
device,
|
||||
);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn read_partial_unit_test<R: JitRuntime>(device: &R::Device) {
|
||||
let tensor = range_tensor::<R>(4, 4, device);
|
||||
let tile = create_empty::<R>(4, 4, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
||||
read_partial_test_launch::<F32, R>(
|
||||
tensor.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::vectorized(
|
||||
TILE_SIZE as u8,
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
),
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
|
||||
TILE_SIZE.into(),
|
||||
false,
|
||||
);
|
||||
|
||||
let expected = &[
|
||||
8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
assert_equals::<R>(tile, expected, device);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn load_tile_no_checks_unit_test<R: JitRuntime>(device: &R::Device) {
|
||||
let tensor = range_tensor::<R>(8, 8, device);
|
||||
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
||||
let config = make_config(8, 8, 8);
|
||||
|
||||
load_tile_test_launch::<F32, R>(
|
||||
tensor.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::vectorized(
|
||||
TILE_SIZE as u8,
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
),
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
|
||||
0,
|
||||
0,
|
||||
config,
|
||||
);
|
||||
|
||||
let expected = &[
|
||||
0.0, 1.0, 2.0, 3.0, 8.0, 9.0, 10.0, 11.0, 16.0, 17.0, 18.0, 19.0, 24.0, 25.0, 26.0,
|
||||
27.0,
|
||||
];
|
||||
assert_equals::<R>(tile, expected, device);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn load_tile_vertical_checks_unit_test<R: JitRuntime>(device: &R::Device) {
|
||||
let tensor = range_tensor::<R>(6, 8, device);
|
||||
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
||||
let config = make_config(6, 8, 8);
|
||||
|
||||
load_tile_test_launch::<F32, R>(
|
||||
tensor.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::vectorized(
|
||||
TILE_SIZE as u8,
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
),
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
|
||||
4,
|
||||
0,
|
||||
config,
|
||||
);
|
||||
|
||||
let expected = &[
|
||||
32.0, 33.0, 34.0, 35.0, 40.0, 41.0, 42.0, 43.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
assert_equals::<R>(tile, expected, device);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn load_tile_horizontal_checks_unit_test<R: JitRuntime>(device: &R::Device) {
|
||||
let tensor = range_tensor::<R>(8, 4, device);
|
||||
let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
||||
let config = make_config(8, 4, 8);
|
||||
|
||||
load_tile_test_launch::<F32, R>(
|
||||
tensor.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
TensorArg::vectorized(
|
||||
TILE_SIZE as u8,
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
),
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
|
||||
0,
|
||||
4,
|
||||
config,
|
||||
);
|
||||
|
||||
let expected = &[
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
assert_equals::<R>(tile, expected, device);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn write_tile_plain_unit_test<R: JitRuntime>(device: &R::Device) {
|
||||
let tile = range_tensor::<R>(4, 4, device);
|
||||
let sm_out = create_empty::<R>(8, 8, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
||||
let config = make_config(8, 8, 8);
|
||||
|
||||
write_tile_test_launch::<F32, R>(
|
||||
tile.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &tile.handle, 4),
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 4),
|
||||
config,
|
||||
false,
|
||||
);
|
||||
|
||||
let expected = &[
|
||||
0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0, 8.0,
|
||||
9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
assert_equals::<R>(sm_out, expected, device);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn write_tile_transposed_unit_test<R: JitRuntime>(device: &R::Device) {
|
||||
let tile = range_tensor::<R>(4, 4, device);
|
||||
let sm_out = create_empty::<R>(8, 8, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
||||
let config = make_config(8, 8, 8);
|
||||
|
||||
write_tile_test_launch::<F32, R>(
|
||||
tile.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &tile.handle, 4),
|
||||
ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
|
||||
config,
|
||||
true,
|
||||
);
|
||||
|
||||
let expected = &[
|
||||
0.0, 4.0, 8.0, 12.0, 0.0, 0.0, 0.0, 0.0, 1.0, 5.0, 9.0, 13.0, 0.0, 0.0, 0.0, 0.0, 2.0,
|
||||
6.0, 10.0, 14.0, 0.0, 0.0, 0.0, 0.0, 3.0, 7.0, 11.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
|
||||
];
|
||||
assert_equals::<R>(sm_out, expected, device);
|
||||
}
|
||||
|
||||
/// Exported test
|
||||
pub fn load_lhs_transposed_unit_test<R: JitRuntime>(device: &R::Device) {
|
||||
let lhs = range_tensor::<R>(16, 16, device);
|
||||
|
|
|
@ -5,6 +5,8 @@ mod load_shared_memory;
|
|||
mod outer_product;
|
||||
#[cfg(feature = "export_tests")]
|
||||
mod test_utils;
|
||||
mod tile_read;
|
||||
mod tile_write;
|
||||
mod write_output;
|
||||
|
||||
pub use base::matmul_tiling_2d_cube;
|
||||
|
@ -13,5 +15,6 @@ pub use base::matmul_tiling_2d_cube;
|
|||
pub use {
|
||||
compute_loop::tests as compute_loop_tests,
|
||||
load_shared_memory::tests as load_shared_memory_tests,
|
||||
outer_product::tests as outer_product_tests, write_output::tests as write_output_tests,
|
||||
outer_product::tests as outer_product_tests, tile_read::tests as tile_read_tests,
|
||||
tile_write::tests as tile_write_tests, write_output::tests as write_output_tests,
|
||||
};
|
||||
|
|
|
@ -0,0 +1,672 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
use crate::kernel::matmul::config::CubeTiling2dConfig;
|
||||
|
||||
#[cube]
/// Reads one tile of `tile_size` lines from global memory into `tile`,
/// dispatching at compile time to the variant with exactly the bounds
/// checks the caller requested.
///
/// * `cube_offset` - base offset of this cube within the tensor.
/// * `read_row`/`read_col` - tile coordinates relative to the cube.
/// * `skip_row`/`skip_col` - offsets added to the read coordinates when
///   comparing against `dim_vertical`/`dim_horizontal`.
/// * `check_vertical_bounds`/`check_horizontal_bounds` - comptime flags;
///   each `true` compiles in the corresponding runtime bounds check.
pub(crate) fn read_tile_from_global_memory<F: Float>(
    tensor: &Tensor<F>,
    tile: &mut Array<F>,
    cube_offset: UInt,
    read_row: UInt,
    read_col: UInt,
    skip_row: UInt,
    skip_col: UInt,
    tensor_stride: UInt,
    dim_vertical: UInt,
    dim_horizontal: UInt,
    check_vertical_bounds: Comptime<bool>,
    check_horizontal_bounds: Comptime<bool>,
    config: Comptime<CubeTiling2dConfig>,
) {
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let unroll = Comptime::map(config, |c| c.unroll_tile);

    // Linear position of the tile's first element in the strided tensor.
    let tensor_position_base = read_row * tensor_stride + read_col + cube_offset;

    if Comptime::get(check_vertical_bounds) {
        let row = skip_row + read_row;

        if Comptime::get(check_horizontal_bounds) {
            let col = skip_col + read_col;
            read_with_both_checks::<F>(
                tensor,
                row,
                col,
                tensor_position_base,
                tensor_stride,
                dim_vertical,
                dim_horizontal,
                tile,
                tile_size,
                unroll,
            );
        } else {
            read_with_vertical_checks::<F>(
                tensor,
                row,
                tensor_position_base,
                tensor_stride,
                dim_vertical,
                tile,
                tile_size,
                unroll,
            );
        }
    } else if Comptime::get(check_horizontal_bounds) {
        let col = skip_col + read_col;
        read_with_horizontal_checks::<F>(
            tensor,
            col,
            tensor_position_base,
            tensor_stride,
            dim_horizontal,
            tile,
            tile_size,
            unroll,
        );
    } else {
        // Fast path: no bounds checks compiled in at all.
        read_without_checks::<F>(
            tensor,
            tensor_position_base,
            tensor_stride,
            tile,
            tile_size,
            unroll,
        );
    }
}
|
||||
|
||||
#[cube]
/// Reads up to `tile_size` lines of the tile, checking both vertical and
/// horizontal tensor bounds; lines at or past `dim_vertical` are zero-filled.
fn read_with_both_checks<F: Float>(
    tensor: &Tensor<F>,
    row: UInt,
    col: UInt,
    position_base: UInt,
    stride: UInt,
    dim_vertical: UInt,
    dim_horizontal: UInt,
    tile: &mut Array<F>,
    tile_size: Comptime<UInt>,
    unroll: Comptime<bool>,
) {
    let tile_size_runtime = Comptime::runtime(tile_size);

    // Number of in-bounds lines starting at `row`; stays zero when the whole
    // tile is below the tensor vertically.
    let mut num_reads = UInt::new(0);
    if dim_vertical > row {
        num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
    }

    for i in range(0u32, num_reads, Comptime::new(false)) {
        read_tile_line_with_checks::<F>(
            tensor,
            col,
            position_base,
            stride,
            dim_horizontal,
            tile,
            i,
            tile_size,
            unroll,
        );
    }

    // Pad the remaining (out-of-bounds) lines with zeros.
    let zeros = F::vectorized(0., Comptime::get(tile_size));
    for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
        tile[i] = zeros;
    }
}
|
||||
|
||||
#[cube]
/// Reads up to `tile_size` lines of the tile, checking only the vertical
/// bound; lines at or past `dim_vertical` are zero-filled. Each in-bounds
/// line is read without horizontal checks.
fn read_with_vertical_checks<F: Float>(
    tensor: &Tensor<F>,
    row: UInt,
    position_base: UInt,
    stride: UInt,
    dim_vertical: UInt,
    tile: &mut Array<F>,
    tile_size: Comptime<UInt>,
    unroll: Comptime<bool>,
) {
    let tile_size_runtime = Comptime::runtime(tile_size);

    // Number of in-bounds lines starting at `row`; stays zero when the whole
    // tile is below the tensor vertically.
    let mut num_reads = UInt::new(0);
    if dim_vertical > row {
        num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
    }

    for i in range(0u32, num_reads, Comptime::new(false)) {
        read_tile_line_without_checks::<F>(
            tensor,
            position_base,
            stride,
            tile,
            i,
            tile_size,
            unroll,
        );
    }

    // Pad the remaining (out-of-bounds) lines with zeros.
    let zeros = F::vectorized(0., Comptime::get(tile_size));
    for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
        tile[i] = zeros;
    }
}
|
||||
|
||||
#[cube]
/// Reads all `tile_size` lines of the tile with no bounds checks; the caller
/// guarantees the tile is fully inside the tensor.
fn read_without_checks<F: Float>(
    tensor: &Tensor<F>,
    position_base: UInt,
    stride: UInt,
    tile: &mut Array<F>,
    tile_size: Comptime<UInt>,
    unroll: Comptime<bool>,
) {
    for i in range(0u32, Comptime::get(tile_size), unroll) {
        read_tile_line_without_checks::<F>(
            tensor,
            position_base,
            stride,
            tile,
            i,
            tile_size,
            unroll,
        );
    }
}
|
||||
|
||||
#[cube]
/// Reads all `tile_size` lines of the tile, checking only the horizontal
/// bound within each line (the caller guarantees vertical validity).
fn read_with_horizontal_checks<F: Float>(
    tensor: &Tensor<F>,
    col: UInt,
    position_base: UInt,
    stride: UInt,
    dim_horizontal: UInt,
    tile: &mut Array<F>,
    tile_size: Comptime<UInt>,
    unroll: Comptime<bool>,
) {
    for i in range(0u32, Comptime::get(tile_size), unroll) {
        read_tile_line_with_checks::<F>(
            tensor,
            col,
            position_base,
            stride,
            dim_horizontal,
            tile,
            i,
            tile_size,
            unroll,
        );
    }
}
|
||||
|
||||
#[cube]
/// Reads line `i` of the tile with a horizontal bounds check against
/// `dim_horizontal`.
fn read_tile_line_with_checks<F: Float>(
    tensor: &Tensor<F>,
    col: UInt,
    position_base: UInt,
    stride: UInt,
    dim_horizontal: UInt,
    tile: &mut Array<F>,
    i: UInt,
    tile_size: Comptime<UInt>,
    unroll: Comptime<bool>,
) {
    let vectorization_factor = Comptime::vectorization(tensor);
    let runtime_vectorization = Comptime::runtime(vectorization_factor);

    let position = position_base + i * stride;

    if tile_size == vectorization_factor {
        // One vectorized load covers the whole line: either the line starts
        // in bounds and is loaded whole, or it is fully zeroed.
        if col >= dim_horizontal {
            tile[i] = F::vectorized(0., Comptime::get(tile_size));
        } else {
            tile[i] = tensor[position / runtime_vectorization];
        }
    } else {
        // Vectorization is narrower than the tile: assemble the line from
        // several (possibly scalar) loads, reading only in-bounds vectors.
        let tile_entry = F::vectorized_empty(Comptime::get(tile_size));

        // Number of whole vectors that fit before `dim_horizontal`.
        let mut num_loops = UInt::new(0);
        if dim_horizontal > col {
            let num_reads = UInt::min(dim_horizontal - col, Comptime::runtime(tile_size));
            num_loops = num_reads / runtime_vectorization;
        }

        for x in range(0u32, num_loops, Comptime::new(false)) {
            read_within_vector::<F>(
                tensor,
                tile_entry,
                position,
                x,
                vectorization_factor,
                unroll,
            );
        }

        // NOTE(review): lanes past the last in-bounds vector are left as
        // `vectorized_empty` (uninitialized), not zeroed — confirm downstream
        // never reads them when this branch is taken out of bounds.
        tile[i] = tile_entry;
    }
}
|
||||
|
||||
#[cube]
/// Reads line `i` of the tile with no bounds checks.
fn read_tile_line_without_checks<F: Float>(
    tensor: &Tensor<F>,
    position_base: UInt,
    stride: UInt,
    tile: &mut Array<F>,
    i: UInt,
    tile_size: Comptime<UInt>,
    unroll: Comptime<bool>,
) {
    let vectorization_factor = Comptime::vectorization(tensor);
    let runtime_vectorization = Comptime::runtime(vectorization_factor);

    let position = position_base + i * stride;

    if tile_size == vectorization_factor {
        // One vectorized load covers the whole line.
        tile[i] = tensor[position / runtime_vectorization];
    } else {
        // Vectorization is narrower than the tile: assemble the line from
        // `tile_size / vectorization_factor` smaller loads.
        let tile_entry = F::vectorized_empty(Comptime::get(tile_size));

        for j in range(
            0u32,
            Comptime::get(tile_size / vectorization_factor),
            unroll,
        ) {
            read_within_vector::<F>(
                tensor,
                tile_entry,
                position,
                j,
                vectorization_factor,
                unroll,
            );
        }

        tile[i] = tile_entry;
    }
}
|
||||
|
||||
#[cube]
/// Necessary when vectorization_factor < tile_size: copies the `i`-th
/// sub-vector (or scalar) of the tensor line into the matching lanes of
/// `tile_entry`.
fn read_within_vector<F: Float>(
    tensor: &Tensor<F>,
    mut tile_entry: F,
    position: UInt,
    i: UInt,
    vectorization_factor: Comptime<UInt>,
    unroll: Comptime<bool>,
) {
    let is_scalar = Comptime::map(vectorization_factor, |v| v.val == 1);
    let runtime_vectorization = Comptime::runtime(vectorization_factor);

    if Comptime::get(is_scalar) {
        // Scalar tensor: copy one element directly.
        tile_entry[i] = tensor[position + i];
    } else {
        // Load one vector, then scatter its lanes into `tile_entry`.
        let intermediate = tensor[position / runtime_vectorization + i];

        for j in range(0u32, Comptime::get(vectorization_factor), unroll) {
            tile_entry[i * runtime_vectorization + j] = intermediate[j];
        }
    }
}
|
||||
|
||||
#[cfg(feature = "export_tests")]
/// Exported tests for reading tiles in global memory
pub mod tests {
    use crate::kernel::matmul::tiling2d_cube::test_utils::{
        assert_equals, create_empty, make_config, range_tensor, TILE_SIZE,
    };
    use crate::JitRuntime;

    use super::*;

    /// Test kernel: reads a whole tile starting at (0, 0), with or without
    /// the horizontal bounds check compiled in.
    #[cube(launch)]
    #[allow(unused_mut)]
    fn read_whole_test<F: Float>(
        tensor: &Tensor<F>,
        tile: &mut Array<F>,
        tile_size: Comptime<UInt>,
        bound_check_horizontal: Comptime<bool>,
    ) {
        if Comptime::get(bound_check_horizontal) {
            read_with_horizontal_checks::<F>(
                tensor,
                UInt::new(0),
                UInt::new(0),
                tensor.stride(0),
                tensor.shape(1),
                tile,
                tile_size,
                Comptime::new(true),
            );
        } else {
            read_without_checks::<F>(
                tensor,
                UInt::new(0),
                tensor.stride(0),
                tile,
                tile_size,
                Comptime::new(true),
            );
        }
    }

    /// Test kernel: reads a tile that partially hangs below the tensor,
    /// exercising the both-checks or vertical-checks variant.
    #[cube(launch)]
    #[allow(unused_mut)]
    fn read_partial_test<F: Float>(
        tensor: &Tensor<F>,
        tile: &mut Array<F>,
        tile_size: Comptime<UInt>,
        bound_check_horizontal: Comptime<bool>,
    ) {
        if Comptime::get(bound_check_horizontal) {
            // NOTE(review): this branch passes col = 8, position_base = 0,
            // while the vertical-checks branch below passes row = 2,
            // position_base = 8 — the argument order looks inconsistent;
            // confirm intent (no exported test currently takes this branch).
            read_with_both_checks::<F>(
                tensor,
                UInt::new(2),
                UInt::new(8),
                UInt::new(0),
                tensor.stride(0),
                tensor.shape(0),
                tensor.shape(1),
                tile,
                tile_size,
                Comptime::new(true),
            );
        } else {
            read_with_vertical_checks::<F>(
                tensor,
                UInt::new(2),
                UInt::new(8),
                tensor.stride(0),
                tensor.shape(0),
                tile,
                tile_size,
                Comptime::new(true),
            );
        }
    }

    /// Test kernel: full entry point — reads a tile through
    /// `read_tile_from_global_memory`, deriving bounds-check flags and
    /// dimensions from the config and tensor shape (lhs layout: m x k).
    #[cube(launch)]
    fn read_tile_test<F: Float>(
        lhs: &Tensor<F>,
        tile: &mut Array<F>,
        unit_row: UInt,
        unit_col: UInt,
        config: Comptime<CubeTiling2dConfig>,
    ) {
        let cube_offset = UInt::new(0);
        let check_vertical_bounds = Comptime::map(config, |c| c.check_m_bounds);
        let check_horizontal_bounds = Comptime::map(config, |c| c.check_k_bounds);
        let lhs_stride = lhs.stride(lhs.rank() - UInt::new(2));
        let dim_m = lhs.shape(lhs.rank() - UInt::new(2));
        let dim_k = lhs.shape(lhs.rank() - UInt::new(1));

        read_tile_from_global_memory::<F>(
            lhs,
            tile,
            cube_offset,
            unit_row,
            unit_col,
            UInt::new(0),
            UInt::new(0),
            lhs_stride,
            dim_m,
            dim_k,
            check_vertical_bounds,
            check_horizontal_bounds,
            config,
        );
    }

    /// Exported test: tensor vectorization equals the tile size, so each
    /// line is one vectorized load.
    pub fn read_whole_vectorized_like_tile_test<R: JitRuntime>(device: &R::Device) {
        let tensor = range_tensor::<R>(4, 4, device);
        let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);

        read_whole_test_launch::<F32, R>(
            tensor.client.clone(),
            cube_count,
            cube_dim,
            TensorArg::vectorized(
                TILE_SIZE as u8,
                &tensor.handle,
                &tensor.strides,
                &tensor.shape.dims,
            ),
            ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
            TILE_SIZE.into(),
            false, // no horizontal bounds check
        );

        // Whole 4x4 range tensor copied verbatim.
        assert_equals::<R>(
            tile,
            &[
                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                15.0,
            ],
            device,
        );
    }

    /// Exported test: tensor vectorization (2) is smaller than the tile
    /// size, forcing line assembly from multiple vector loads.
    pub fn read_whole_vectorized_less_than_tile_test<R: JitRuntime>(device: &R::Device) {
        let vectorization_factor = 2;
        let tensor = range_tensor::<R>(4, 4, device);
        let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);

        read_whole_test_launch::<F32, R>(
            tensor.client.clone(),
            cube_count,
            cube_dim,
            TensorArg::vectorized(
                vectorization_factor,
                &tensor.handle,
                &tensor.strides,
                &tensor.shape.dims,
            ),
            ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
            TILE_SIZE.into(),
            false, // no horizontal bounds check
        );

        // Same result as the vectorized-like-tile case.
        assert_equals::<R>(
            tile,
            &[
                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                15.0,
            ],
            device,
        );
    }

    /// Exported test: scalar tensor (vectorization 1) — each element is
    /// loaded individually.
    pub fn read_whole_scalar_test<R: JitRuntime>(device: &R::Device) {
        let vectorization_factor = 1;
        let tensor = range_tensor::<R>(4, 4, device);
        let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);

        read_whole_test_launch::<F32, R>(
            tensor.client.clone(),
            cube_count,
            cube_dim,
            TensorArg::vectorized(
                vectorization_factor,
                &tensor.handle,
                &tensor.strides,
                &tensor.shape.dims,
            ),
            ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
            TILE_SIZE.into(),
            false, // no horizontal bounds check
        );

        assert_equals::<R>(
            tile,
            &[
                0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0,
                15.0,
            ],
            device,
        );
    }

    /// Exported test: tensor narrower (4x2) than the tile with horizontal
    /// bounds check enabled — out-of-bounds columns come back zero.
    /// NOTE(review): despite the name, this uses vectorization factor 2,
    /// not scalar — confirm the test name is intended.
    pub fn read_whole_scalar_out_of_bound_test<R: JitRuntime>(device: &R::Device) {
        let vectorization_factor = 2;
        let tensor = range_tensor::<R>(4, 2, device);
        let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);

        read_whole_test_launch::<F32, R>(
            tensor.client.clone(),
            cube_count,
            cube_dim,
            TensorArg::vectorized(
                vectorization_factor,
                &tensor.handle,
                &tensor.strides,
                &tensor.shape.dims,
            ),
            ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
            TILE_SIZE.into(),
            true, // horizontal bounds check on
        );

        // Each tile line holds the 2 valid columns followed by zero padding.
        assert_equals::<R>(
            tile,
            &[
                0.0, 1.0, 0.0, 0.0, 2.0, 3.0, 0.0, 0.0, 4.0, 5.0, 0.0, 0.0, 6.0, 7.0, 0.0, 0.0,
            ],
            device,
        );
    }

    /// Exported test: vertical-checks path — tile starts at row 2 of a 4-row
    /// tensor, so only the first two lines are read, the rest zero-filled.
    pub fn read_partial_unit_test<R: JitRuntime>(device: &R::Device) {
        let tensor = range_tensor::<R>(4, 4, device);
        let tile = create_empty::<R>(4, 4, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);

        read_partial_test_launch::<F32, R>(
            tensor.client.clone(),
            cube_count,
            cube_dim,
            TensorArg::vectorized(
                TILE_SIZE as u8,
                &tensor.handle,
                &tensor.strides,
                &tensor.shape.dims,
            ),
            ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
            TILE_SIZE.into(),
            false, // vertical checks only
        );

        let expected = &[
            8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];
        assert_equals::<R>(tile, expected, device);
    }

    /// Exported test: full tile read with no bounds checks (8x8 tensor,
    /// config 8/8/8 — tile fully in bounds).
    pub fn read_tile_no_checks_unit_test<R: JitRuntime>(device: &R::Device) {
        let tensor = range_tensor::<R>(8, 8, device);
        let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);

        let config = make_config(8, 8, 8);

        read_tile_test_launch::<F32, R>(
            tensor.client.clone(),
            cube_count,
            cube_dim,
            TensorArg::vectorized(
                TILE_SIZE as u8,
                &tensor.handle,
                &tensor.strides,
                &tensor.shape.dims,
            ),
            ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
            0, // unit_row
            0, // unit_col
            config,
        );

        // Top-left 4x4 corner of the 8x8 range tensor (row stride 8).
        let expected = &[
            0.0, 1.0, 2.0, 3.0, 8.0, 9.0, 10.0, 11.0, 16.0, 17.0, 18.0, 19.0, 24.0, 25.0, 26.0,
            27.0,
        ];
        assert_equals::<R>(tile, expected, device);
    }

    /// Exported test: m = 6 < block size triggers vertical checks; reading
    /// at row 4 leaves the last two tile lines zero-filled.
    pub fn read_tile_vertical_checks_unit_test<R: JitRuntime>(device: &R::Device) {
        let tensor = range_tensor::<R>(6, 8, device);
        let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);

        let config = make_config(6, 8, 8);

        read_tile_test_launch::<F32, R>(
            tensor.client.clone(),
            cube_count,
            cube_dim,
            TensorArg::vectorized(
                TILE_SIZE as u8,
                &tensor.handle,
                &tensor.strides,
                &tensor.shape.dims,
            ),
            ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
            4, // unit_row: last two rows in bounds
            0, // unit_col
            config,
        );

        let expected = &[
            32.0, 33.0, 34.0, 35.0, 40.0, 41.0, 42.0, 43.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];
        assert_equals::<R>(tile, expected, device);
    }

    /// Exported test: k = 4 < block size triggers horizontal checks; reading
    /// at column 4 is entirely out of bounds, so the tile is all zeros.
    pub fn read_tile_horizontal_checks_unit_test<R: JitRuntime>(device: &R::Device) {
        let tensor = range_tensor::<R>(8, 4, device);
        let tile = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);

        let config = make_config(8, 4, 8);

        read_tile_test_launch::<F32, R>(
            tensor.client.clone(),
            cube_count,
            cube_dim,
            TensorArg::vectorized(
                TILE_SIZE as u8,
                &tensor.handle,
                &tensor.strides,
                &tensor.shape.dims,
            ),
            ArrayArg::vectorized(TILE_SIZE as u8, &tile, 4),
            0, // unit_row
            4, // unit_col: past the 4-column tensor
            config,
        );

        let expected = &[
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];
        assert_equals::<R>(tile, expected, device);
    }
}
|
|
@ -0,0 +1,210 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
use crate::kernel::matmul::config::CubeTiling2dConfig;
|
||||
|
||||
#[cube]
/// Writes a tile into shared memory keeping its row/column layout.
///
/// `sm_stride` is the shared-memory row stride in unvectorized elements;
/// positions are divided by `tile_size` because the shared memory is
/// vectorized by `tile_size`. When `check_sm_bounds` is set, a write whose
/// row falls at or past `block_size_k` is silently dropped.
pub(crate) fn write_tile_plain<F: Float>(
    tile: &Array<F>,
    mut shared_memory: SharedMemory<F>,
    write_row: UInt,
    write_col: UInt,
    sm_stride: UInt,
    config: Comptime<CubeTiling2dConfig>,
) {
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let unroll = Comptime::map(config, |c| c.unroll_tile);
    let check_sm_bounds = Comptime::map(config, |c| c.check_sm_bounds);
    let tile_size_runtime = Comptime::runtime(tile_size);

    let sm_position_base = write_row * sm_stride + write_col;

    if Comptime::get(check_sm_bounds) {
        let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
        if write_row < sm_dim_vertical {
            for i in range(0u32, Comptime::get(tile_size), unroll) {
                shared_memory[(sm_position_base + i * sm_stride) / tile_size_runtime] = tile[i];
            }
        }
    } else {
        // No bounds check compiled in: write every tile line.
        for i in range(0u32, Comptime::get(tile_size), unroll) {
            shared_memory[(sm_position_base + i * sm_stride) / tile_size_runtime] = tile[i];
        }
    }
}
|
||||
|
||||
#[cube]
/// Writes a tile into shared memory with rows and columns swapped.
///
/// A scalar tile (tile_size == 1) is its own transpose, so it is written
/// directly; otherwise the transposition is delegated to
/// `transpose_tile_to_shared_memory`. When `check_sm_bounds` is set, writes
/// whose row falls at or past `block_size_k` are silently dropped.
pub(crate) fn write_tile_transposed<F: Float>(
    tile: &Array<F>,
    mut shared_memory: SharedMemory<F>,
    write_row: UInt,
    write_col: UInt,
    sm_stride: UInt,
    config: Comptime<CubeTiling2dConfig>,
) {
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let check_sm_bounds = Comptime::map(config, |c| c.check_sm_bounds);
    // Comptime flag: a 1-element tile needs no transposition.
    let is_scalar = Comptime::map(tile_size, |c| c.val == 1);

    let sm_position_base = write_row * sm_stride + write_col;

    if Comptime::get(is_scalar) {
        if Comptime::get(check_sm_bounds) {
            let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
            if write_row < sm_dim_vertical {
                shared_memory[sm_position_base] = tile[0];
            }
        } else {
            shared_memory[sm_position_base] = tile[0];
        }
    } else if Comptime::get(check_sm_bounds) {
        let sm_dim_vertical = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
        if write_row < sm_dim_vertical {
            transpose_tile_to_shared_memory::<F>(
                tile,
                shared_memory,
                sm_position_base,
                sm_stride,
                config,
            );
        }
    } else {
        transpose_tile_to_shared_memory::<F>(
            tile,
            shared_memory,
            sm_position_base,
            sm_stride,
            config,
        );
    }
}
|
||||
|
||||
#[cube]
/// Transposes a `tile_size` x `tile_size` tile while writing it to shared
/// memory: output line `i` gathers lane `i` of every tile line.
fn transpose_tile_to_shared_memory<F: Float>(
    tile: &Array<F>,
    mut shared_memory: SharedMemory<F>,
    sm_position_base: UInt,
    sm_stride: UInt,
    config: Comptime<CubeTiling2dConfig>,
) {
    let tile_size = Comptime::map(config, |c| c.tile_size);
    let unroll = Comptime::map(config, |c| c.unroll_tile);

    for i in range(0u32, Comptime::get(tile_size), unroll) {
        // Gather column `i` of the tile into one vectorized value.
        let mut transposed = F::vectorized_empty(Comptime::get(tile_size));

        // Unrolling this one makes the difference
        for j in range(0u32, Comptime::get(tile_size), unroll) {
            transposed[j] = tile[j][i];
        }

        // Shared memory is vectorized by tile_size, hence the division.
        let sm_position = (sm_position_base + i * sm_stride) / Comptime::runtime(tile_size);
        shared_memory[sm_position] = transposed;
    }
}
|
||||
|
||||
#[cfg(feature = "export_tests")]
/// Exported tests for writing tiles to shared memory
pub mod tests {
    use crate::kernel::matmul::tiling2d_cube::test_utils::{
        assert_equals, create_empty, make_config, range_tensor, TILE_SIZE,
    };
    use crate::JitRuntime;

    use super::*;

    /// Test kernel: writes `tile` at (0, 0) of a freshly allocated shared
    /// memory (plain or transposed, chosen at compile time), then copies the
    /// shared memory into `sm_out` so the host can inspect it.
    #[cube(launch)]
    fn write_tile_test<F: Float>(
        tile: &Array<F>,
        sm_out: &mut Array<F>,
        config: Comptime<CubeTiling2dConfig>,
        transposed: Comptime<bool>,
    ) {
        let tile_size = Comptime::map(config, |c| c.tile_size);
        let block_size_m = Comptime::map(config, |c| c.block_size_m);
        let block_size_k = Comptime::map(config, |c| c.block_size_k);

        let sm_stride = block_size_m;
        let sm_size = Comptime::runtime(block_size_k * block_size_m);
        let shared_memory = SharedMemory::<F>::vectorized(sm_size, Comptime::get(tile_size));

        if Comptime::get(transposed) {
            write_tile_transposed(
                tile,
                shared_memory,
                UInt::new(0),
                UInt::new(0),
                Comptime::runtime(sm_stride),
                config,
            );
        } else {
            write_tile_plain(
                tile,
                shared_memory,
                UInt::new(0),
                UInt::new(0),
                Comptime::runtime(sm_stride),
                config,
            );
        }

        // Dump the shared memory so the host-side assertion can read it.
        for i in range(0u32, sm_size, Comptime::new(false)) {
            sm_out[i] = shared_memory[i];
        }
    }

    /// Exported test: writes a 4x4 tile into 8x8 shared memory without
    /// transposition and checks it lands row-major in the top-left corner.
    pub fn write_tile_plain_unit_test<R: JitRuntime>(device: &R::Device) {
        let tile = range_tensor::<R>(4, 4, device);
        let sm_out = create_empty::<R>(8, 8, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);

        let config = make_config(8, 8, 8);

        write_tile_test_launch::<F32, R>(
            tile.client.clone(),
            cube_count,
            cube_dim,
            ArrayArg::vectorized(TILE_SIZE as u8, &tile.handle, 4),
            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 4),
            config,
            false, // transposed = false
        );

        // Tile rows appear unchanged at stride 8; the rest stays zero.
        let expected = &[
            0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0, 8.0,
            9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, 13.0, 14.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];
        assert_equals::<R>(sm_out, expected, device);
    }

    /// Exported test: writes a 4x4 tile into 8x8 shared memory transposed
    /// and checks that rows and columns are swapped.
    pub fn write_tile_transposed_unit_test<R: JitRuntime>(device: &R::Device) {
        let tile = range_tensor::<R>(4, 4, device);
        let sm_out = create_empty::<R>(8, 8, device);
        let cube_dim = CubeDim::new(1, 1, 1);
        let cube_count = CubeCount::new(1, 1, 1);

        let config = make_config(8, 8, 8);

        write_tile_test_launch::<F32, R>(
            tile.client.clone(),
            cube_count,
            cube_dim,
            ArrayArg::vectorized(TILE_SIZE as u8, &tile.handle, 4),
            // NOTE(review): length passed as 64 here but 4 in the plain test
            // above — confirm which unit ArrayArg expects.
            ArrayArg::vectorized(TILE_SIZE as u8, &sm_out, 64),
            config,
            true, // transposed = true
        );

        // Tile columns appear as rows at stride 8; the rest stays zero.
        let expected = &[
            0.0, 4.0, 8.0, 12.0, 0.0, 0.0, 0.0, 0.0, 1.0, 5.0, 9.0, 13.0, 0.0, 0.0, 0.0, 0.0, 2.0,
            6.0, 10.0, 14.0, 0.0, 0.0, 0.0, 0.0, 3.0, 7.0, 11.0, 15.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
            0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];
        assert_equals::<R>(sm_out, expected, device);
    }
}
|
|
@ -2,7 +2,8 @@
|
|||
mod tests {
|
||||
use super::*;
|
||||
use burn_jit::kernel::matmul::tiling2d_cube::{
|
||||
compute_loop_tests, load_shared_memory_tests, outer_product_tests, write_output_tests,
|
||||
compute_loop_tests, load_shared_memory_tests, outer_product_tests, tile_read_tests,
|
||||
tile_write_tests, write_output_tests,
|
||||
};
|
||||
use burn_jit::kernel::matmul::{matmul, MatmulStrategy, Tiling2dConfig};
|
||||
use burn_tensor::{Shape, Tensor};
|
||||
|
@ -38,62 +39,54 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
pub fn tiling2d_matmul_read_whole_vectorized_like_tile_test() {
|
||||
load_shared_memory_tests::read_whole_vectorized_like_tile_test::<TestRuntime>(
|
||||
&Default::default(),
|
||||
)
|
||||
tile_read_tests::read_whole_vectorized_like_tile_test::<TestRuntime>(&Default::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn tiling2d_matmul_read_whole_vectorized_less_than_tile_test() {
|
||||
load_shared_memory_tests::read_whole_vectorized_less_than_tile_test::<TestRuntime>(
|
||||
tile_read_tests::read_whole_vectorized_less_than_tile_test::<TestRuntime>(
|
||||
&Default::default(),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn tiling2d_matmul_read_whole_scalar_test() {
|
||||
load_shared_memory_tests::read_whole_scalar_test::<TestRuntime>(&Default::default())
|
||||
tile_read_tests::read_whole_scalar_test::<TestRuntime>(&Default::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn read_whole_scalar_out_of_bound_test() {
|
||||
load_shared_memory_tests::read_whole_scalar_out_of_bound_test::<TestRuntime>(
|
||||
&Default::default(),
|
||||
)
|
||||
tile_read_tests::read_whole_scalar_out_of_bound_test::<TestRuntime>(&Default::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn tiling2d_matmul_read_partial_unit_test() {
|
||||
load_shared_memory_tests::read_partial_unit_test::<TestRuntime>(&Default::default())
|
||||
tile_read_tests::read_partial_unit_test::<TestRuntime>(&Default::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn tiling2d_matmul_load_tile_no_checks_unit_test() {
|
||||
load_shared_memory_tests::load_tile_no_checks_unit_test::<TestRuntime>(&Default::default())
|
||||
tile_read_tests::read_tile_no_checks_unit_test::<TestRuntime>(&Default::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn tiling2d_matmul_load_tile_vertical_checks_unit_test() {
|
||||
load_shared_memory_tests::load_tile_vertical_checks_unit_test::<TestRuntime>(
|
||||
&Default::default(),
|
||||
)
|
||||
tile_read_tests::read_tile_vertical_checks_unit_test::<TestRuntime>(&Default::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn tiling2d_matmul_load_tile_horizontal_checks_unit_test() {
|
||||
load_shared_memory_tests::load_tile_horizontal_checks_unit_test::<TestRuntime>(
|
||||
&Default::default(),
|
||||
)
|
||||
tile_read_tests::read_tile_horizontal_checks_unit_test::<TestRuntime>(&Default::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn write_tile_plain_unit_test() {
|
||||
load_shared_memory_tests::write_tile_plain_unit_test::<TestRuntime>(&Default::default())
|
||||
tile_write_tests::write_tile_plain_unit_test::<TestRuntime>(&Default::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn write_tile_transposed_unit_test() {
|
||||
load_shared_memory_tests::write_tile_transposed_unit_test::<TestRuntime>(&Default::default())
|
||||
tile_write_tests::write_tile_transposed_unit_test::<TestRuntime>(&Default::default())
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
Loading…
Reference in New Issue