debugging wip

louisfd 2024-06-18 21:48:53 -04:00
parent 53cb89bdcb
commit cebd4d4284
3 changed files with 183 additions and 107 deletions


@@ -143,6 +143,16 @@ pub(crate) fn codegen_call(
}
}
}
"zip" => {
let args = &call.args;
// Codegen
quote::quote! {
{
Comptime::zip_expand(#args)
}
}
}
"unwrap_or_else" => {
let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);

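For context, the new "zip" arm follows the same pattern as the other comptime intrinsics: the recognized call is re-emitted as its `_expand` counterpart. Below is a minimal sketch of that rewrite, with assumed `syn`/`quote` types rather than the exact cubecl-macros signatures; note that this arm forwards the arguments verbatim, unlike `unwrap_or_else`, which runs them through `codegen_args`:

```rust
use quote::quote;

// Sketch only (assumed types, not the exact cubecl-macros API): a
// recognized `Comptime::zip(...)` call is re-emitted as a call to its
// `zip_expand` counterpart, forwarding the argument list unchanged.
fn codegen_zip(call: &syn::ExprCall) -> proc_macro2::TokenStream {
    let args = &call.args; // Punctuated<Expr, Comma>: interpolates with commas kept
    quote! {
        {
            Comptime::zip_expand(#args)
        }
    }
}
```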

@@ -35,6 +35,14 @@ impl<T> Comptime<T> {
pub fn map_expand<R, F: Fn(T) -> R>(inner: T, closure: F) -> R {
closure(inner)
}
pub fn zip<R, F: Fn(T, T) -> R>(_comptime: Self, _comptime2: Self, _closure: F) -> Comptime<R> {
unexpanded!()
}
pub fn zip_expand<R, F: Fn(T, T) -> R>(inner1: T, inner2: T, closure: F) -> R {
closure(inner1, inner2)
}
}
impl<T: CubeType + Into<T::ExpandType>> Comptime<Option<T>> {

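For intuition, here is a stand-alone toy analogue of the `zip`/`zip_expand` pair (not the real CubeCL types): inside `#[cube]` code the macro replaces `zip`, whose real body is `unexpanded!()`, with `zip_expand`, which is plain closure application on the inner values:

```rust
// Toy analogue of Comptime's zip/zip_expand (assumed semantics).
struct Comptime<T>(T);

impl<T> Comptime<T> {
    // What the user writes; in CubeCL this body is `unexpanded!()`,
    // here we evaluate directly for illustration.
    fn zip<R, F: Fn(T, T) -> R>(a: Self, b: Self, f: F) -> Comptime<R> {
        Comptime(f(a.0, b.0))
    }

    // What the expansion calls: plain closure application.
    fn zip_expand<R, F: Fn(T, T) -> R>(a: T, b: T, f: F) -> R {
        f(a, b)
    }
}

fn main() {
    let squared = Comptime::zip(Comptime(4u32), Comptime(4u32), |a, b| a * b);
    assert_eq!(squared.0, 16);
    assert_eq!(Comptime::<u32>::zip_expand(3, 5, |a, b| a + b), 8);
}
```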

@@ -36,7 +36,7 @@ impl CubeTiling2dConfig {
fn new(config: Tiling2dConfig, m: usize, k: usize, n: usize) -> Self {
let tile_size = config.tile_size_m;
let sm_size_lhs = config.block_size_m * config.block_size_k * tile_size;
- let sm_size_rhs = config.block_size_k * config.block_size_n * tile_size;
+ let sm_size_rhs = config.block_size_n * config.block_size_k * tile_size;
CubeTiling2dConfig {
block_size_m: UInt::new(config.block_size_m as u32),
@@ -71,8 +71,8 @@ struct Tiling2dState<F: Float> {
pub unit_row: UInt,
pub shared_lhs: SharedMemory<F>,
pub shared_rhs: SharedMemory<F>,
- pub register_m: Array<F>,
- pub register_n: Array<F>,
+ pub register_m: F,
+ pub register_n: F,
pub results: Array<F>,
pub lhs_stride_col: UInt,
pub lhs_stride_row: UInt,
@@ -135,7 +135,7 @@ fn gather_kernel_information<F: Float>(
let col = skip_col + unit_col;
// Batch offset for output
- let offset_output = dim_m * dim_n * batch;
+ let offset_output = dim_m * dim_n * batch / Comptime::runtime(tile_size);
// Calculate offset for lhs and rhs, without regard to batches
let mut offset_lhs = skip_row * lhs_stride_row;
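The new division by `Comptime::runtime(tile_size)` reflects vectorized addressing: with a vectorization factor of 4, one buffer index covers four scalar elements. A worked example with assumed sizes:

```rust
// Worked example (assumed sizes): with tile_size = 4 the output is
// addressed in vec4 units, so the scalar batch offset shrinks by 4x.
fn main() {
    let (dim_m, dim_n, batch, tile_size) = (64u32, 64u32, 2u32, 4u32);
    let scalar_offset = dim_m * dim_n * batch; // 8192 scalar elements
    let vectorized_offset = scalar_offset / tile_size; // 2048 vec4 indices
    assert_eq!(vectorized_offset, 2048);
}
```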
@@ -148,9 +148,11 @@ fn gather_kernel_information<F: Float>(
offset_rhs += tmp % rhs.shape(b) * rhs.stride(b);
}
- let register_m = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
- let register_n = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
- let results = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
+ offset_lhs /= Comptime::runtime(tile_size);
+ offset_rhs /= Comptime::runtime(tile_size);
+ let tile_squared = Comptime::zip(tile_size, tile_size, |c1, c2| UInt::new(c1.val * c2.val));
+ let results = Array::<F>::new(Comptime::get(tile_squared));
let shared_lhs =
SharedMemory::<F>::vectorized(Comptime::get(sm_size_lhs), Comptime::get(tile_size));
@@ -167,8 +169,10 @@ fn gather_kernel_information<F: Float>(
n_loops = dim_k / Comptime::runtime(block_size_k);
}
- // Dummy declaration
+ // Dummy declarations
let k = UInt::new(0);
+ let register_m = F::vectorized(0., Comptime::get(tile_size));
+ let register_n = F::vectorized(0., Comptime::get(tile_size));
Tiling2dState {
n_loops,
@@ -205,7 +209,7 @@ fn load_shared_memory<F: Float>(
kernel_state: Tiling2dState<F>,
config: Comptime<CubeTiling2dConfig>,
) {
- load_lhs_tensor(
+ load_lhs_tensor_plain(
kernel_state.lhs,
kernel_state.offset_lhs,
kernel_state.shared_lhs,
@@ -215,11 +219,12 @@ fn load_shared_memory<F: Float>(
kernel_state.lhs_stride_row,
kernel_state.dim_m,
kernel_state.row,
+ kernel_state.col,
kernel_state.k,
kernel_state.dim_k,
config,
);
- load_rhs_tensor(
+ load_rhs_tensor_transposed(
kernel_state.rhs,
kernel_state.offset_rhs,
kernel_state.shared_rhs,
@@ -228,6 +233,7 @@ fn load_shared_memory<F: Float>(
kernel_state.rhs_stride_row,
kernel_state.rhs_stride_col,
kernel_state.dim_n,
+ kernel_state.row,
kernel_state.col,
kernel_state.k,
kernel_state.dim_k,
@@ -237,7 +243,7 @@ fn load_shared_memory<F: Float>(
#[cube]
/// Assumes vectorization is in the same orientation we need in shared memory
- fn load_lhs_tensor<F: Float>(
+ fn load_lhs_tensor_plain<F: Float>(
lhs: Tensor<F>,
offset_lhs: UInt,
mut shared_lhs: SharedMemory<F>,
@@ -247,6 +253,7 @@ fn load_lhs_tensor<F: Float>(
lhs_stride_row: UInt,
dim_m: UInt,
row: UInt,
+ col: UInt,
k: UInt,
dim_k: UInt,
config: Comptime<CubeTiling2dConfig>,
@@ -257,43 +264,74 @@ fn load_lhs_tensor<F: Float>(
let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds);
let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds);
let position_in_lhs_base = (k + unit_col) * lhs_stride_col
+ unit_row * (lhs_stride_row / Comptime::runtime(tile_size))
+ offset_lhs;
let sm_position_base =
(unit_row * Comptime::runtime(block_size_k) + unit_col) / Comptime::runtime(tile_size);
let sm_stride = Comptime::runtime(block_size_k) / Comptime::runtime(tile_size);
if Comptime::get(check_m_bounds) {
let n_writes = UInt::min(dim_m - row, Comptime::runtime(tile_size));
for i in range(0u32, n_writes, Comptime::new(false)) {
let current_row = unit_row + i;
if Comptime::get(check_k_bounds) {
if k + unit_col >= dim_k {
// Shouldn't be a partial vector, or this vectorization factor would not have been accepted
return;
if Comptime::get(check_k_bounds) {
if col >= dim_k {
for i in range(0u32, Comptime::get(tile_size), unroll) {
let sm_position = sm_position_base + i * sm_stride;
shared_lhs[sm_position] = F::vectorized(0., Comptime::get(tile_size));
}
} else {
let num_reads = UInt::min(dim_m - row, Comptime::runtime(tile_size));
for i in range(0u32, num_reads, Comptime::new(false)) {
let sm_position = sm_position_base + i * sm_stride;
let position_in_lhs =
position_in_lhs_base + i * (lhs_stride_row / Comptime::runtime(tile_size));
shared_lhs[sm_position] = lhs[position_in_lhs];
}
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
let sm_position = sm_position_base + i * sm_stride;
shared_lhs[sm_position] = F::vectorized(0., Comptime::get(tile_size));
}
}
// todo: is this runtime `if` mandatory?
if current_col < Comptime::runtime(block_size_k) {
// todo: runtime consts could be precomputed
let sm_position = current_row
* (Comptime::runtime(block_size_k) / Comptime::runtime(tile_size))
+ (unit_col / Comptime::runtime(tile_size));
// lhs_stride_col should be 1
let position_in_lhs = (k + unit_col) * lhs_stride_col
+ current_row * (lhs_stride_row / Comptime::runtime(tile_size))
+ offset_lhs;
} else {
let num_reads = UInt::min(dim_m - row, Comptime::runtime(tile_size));
for i in range(0u32, num_reads, Comptime::new(false)) {
let sm_position = sm_position_base + i * sm_stride;
let position_in_lhs =
position_in_lhs_base + i * (lhs_stride_row / Comptime::runtime(tile_size));
shared_lhs[sm_position] = lhs[position_in_lhs];
}
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
let sm_position = sm_position_base + i * sm_stride;
shared_lhs[sm_position] = F::vectorized(0., Comptime::get(tile_size));
}
}
} else {
for i in range(0u32, Comptime::get(tile_size), unroll) {
// TODO same loop body as above, could be factored out
if Comptime::get(check_k_bounds) {
if col >= dim_k {
for i in range(0u32, Comptime::get(tile_size), unroll) {
let sm_position = sm_position_base + i * sm_stride;
shared_lhs[sm_position] = F::vectorized(0., Comptime::get(tile_size));
}
} else {
for i in range(0u32, Comptime::get(tile_size), unroll) {
let sm_position = sm_position_base + i * sm_stride;
let position_in_lhs =
position_in_lhs_base + i * (lhs_stride_row / Comptime::runtime(tile_size));
shared_lhs[sm_position] = lhs[position_in_lhs];
}
}
} else {
for i in range(0u32, Comptime::get(tile_size), unroll) {
let sm_position = sm_position_base + i * sm_stride;
let position_in_lhs =
position_in_lhs_base + i * (lhs_stride_row / Comptime::runtime(tile_size));
shared_lhs[sm_position] = lhs[position_in_lhs];
}
}
}
}
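The four branches above all reduce to one pattern: copy the in-bounds rows into shared memory, then zero-fill the remainder of the tile. A plain-Rust sketch of that pattern, with scalars standing in for the vectorized `F` lines:

```rust
// Plain-Rust sketch of the bounds-checked tile load (assumed semantics):
// read the in-bounds rows, leave the rest of the tile zeroed.
fn load_rows_padded(src: &[f32], row: usize, dim_m: usize, tile_size: usize, stride: usize) -> Vec<f32> {
    // mirrors UInt::min(dim_m - row, tile_size); saturating to stay safe
    let num_reads = dim_m.saturating_sub(row).min(tile_size);
    let mut tile = vec![0.0f32; tile_size];
    for i in 0..num_reads {
        tile[i] = src[i * stride]; // mirrors lhs[position_in_lhs_base + i * stride]
    }
    tile // rows num_reads..tile_size stay zero, like the F::vectorized(0., ...) writes
}
```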
#[cube]
- fn load_rhs_tensor<F: Float>(
+ fn load_rhs_tensor_transposed<F: Float>(
rhs: Tensor<F>,
offset_rhs: UInt,
mut shared_rhs: SharedMemory<F>,
@@ -302,14 +340,12 @@ fn load_rhs_tensor<F: Float>(
rhs_stride_row: UInt,
rhs_stride_col: UInt,
dim_n: UInt,
+ row: UInt,
col: UInt,
k: UInt,
dim_k: UInt,
config: Comptime<CubeTiling2dConfig>,
) {
- // TODO:
- // read n element-wise with stride, then store as vectorized-n in sm
let block_size_k = Comptime::map(config, |c| c.block_size_k);
let block_size_n = Comptime::map(config, |c| c.block_size_n);
let unroll = Comptime::map(config, |c| c.unroll);
@@ -317,49 +353,66 @@ fn load_rhs_tensor<F: Float>(
let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds);
let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds);
if Comptime::get(check_n_bounds) {
let n_writes_n = UInt::min(dim_n - col, Comptime::runtime(tile_size));
for j in range(0u32, n_writes, Comptime::new(false)) {
let n_writes_k = UInt::min(dim_k - row, Comptime::runtime(tile_size));
let mut array = Array::<F>::new(Comptime::get(tile_size));
let position_base_no_vec =
(k + current_row) * rhs_stride_row + unit_col * rhs_stride_col + offset_rhs;
let position_base = (k + unit_row) * rhs_stride_row + unit_col * rhs_stride_col + offset_rhs;
let sm_position_base =
(unit_col * Comptime::runtime(block_size_k) + unit_row) / Comptime::runtime(tile_size);
let sm_stride = Comptime::runtime(block_size_k) / Comptime::runtime(tile_size);
for i in range(0u32, n_writes_k, Comptime::new(false)) {
let current_row = unit_row + i;
let current_col = unit_col + j;
// TODO is this necessary?
if current_row < Comptime::runtime(block_size_k) {
// todo: runtime consts could be precomputed
let sm_position = current_row
* (Comptime::runtime(block_size_n) / Comptime::runtime(tile_size))
+ (unit_col / Comptime::runtime(tile_size));
// Unvectorize
// TODO: Should increment second [] if stride_2 is 1
// Plus, entries other than 0 are inaccessible
array[i] = rhs[position_base + i * rhs_stride_col][j];
// Read entries
let mut entries = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
if Comptime::get(check_k_bounds) {
if Comptime::get(check_n_bounds) {
// We assume the whole vectorized read is out of bounds
if col >= dim_n {
for i in range(0u32, Comptime::get(tile_size), unroll) {
entries[i] = F::vectorized(0., Comptime::get(tile_size));
}
// Pad with zeros
if Comptime::get(check_k_bounds) {
for i in range(n_writes, Comptime::get(tile_size), Comptime::new(false)) {
array[i] = F::new(0.);
}
} else {
let num_reads = UInt::min(dim_k - row, Comptime::runtime(tile_size));
for i in range(0u32, num_reads, Comptime::new(false)) {
entries[i] = rhs[position_base + i * rhs_stride_row];
}
// TODO could tile_size be fetched from array length?
// TODO make sure what we write works with what is now read in computation loop
shared_rhs[sm_position] = array.to_vectorized(tile_size);
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
entries[i] = F::vectorized(0., Comptime::get(tile_size));
}
}
} else {
let num_reads = UInt::min(dim_k - row, Comptime::runtime(tile_size));
for i in range(0u32, num_reads, Comptime::new(false)) {
entries[i] = rhs[position_base + i * rhs_stride_row];
}
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
entries[i] = F::vectorized(0., Comptime::get(tile_size));
}
}
} else {
for i in range(0u32, Comptime::get(tile_size), unroll) {
for j in range(0u32, Comptime::get(tile_size), unroll) {
// TODO same loop body as above, could be factored out
if Comptime::get(check_n_bounds) {
// We assume the whole vectorized read is out of bounds
if col >= dim_n {
for i in range(0u32, Comptime::get(tile_size), unroll) {
entries[i] = F::vectorized(0., Comptime::get(tile_size));
}
} else {
for i in range(0u32, Comptime::get(tile_size), unroll) {
entries[i] = rhs[position_base + i * rhs_stride_row];
}
}
} else {
for i in range(0u32, Comptime::get(tile_size), unroll) {
entries[i] = rhs[position_base + i * rhs_stride_row];
}
}
}
// Decompose vectorization then recompose as transposed
for i in range(0u32, Comptime::get(tile_size), unroll) {
let mut transposed = Array::<F>::new(Comptime::get(tile_size));
for j in range(0u32, Comptime::get(tile_size), unroll) {
transposed[j] = entries[j][i];
}
let sm_position = sm_position_base + i * sm_stride;
shared_rhs[sm_position] = transposed.to_vectorized(tile_size);
}
}
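The closing loop is a tile transpose: vectorized rows read from `rhs` are decomposed element-wise and regrouped as columns before the shared-memory store. A scalar sketch, assuming `tile_size = 4`:

```rust
// Scalar sketch of the decompose/recompose step (tile_size = 4 assumed):
// entries[j] is the j-th vectorized row read from rhs; shared memory wants
// columns, so element i of every row is gathered into one output vector.
fn transpose_tile(entries: [[f32; 4]; 4]) -> [[f32; 4]; 4] {
    let mut out = [[0.0f32; 4]; 4];
    for i in 0..4 {
        for j in 0..4 {
            out[i][j] = entries[j][i]; // matches `transposed[j] = entries[j][i]`
        }
    }
    out
}
```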
#[cube]
@@ -379,26 +432,25 @@ fn computation_loop<F: Float>(
let unroll = Comptime::map(config, |c| c.unroll);
let tile_size = Comptime::vectorization(kernel_state.lhs);
for dot_index in range(0u32, Comptime::get(block_size_k), unroll) {
let lhs_pos =
unit_row / Comptime::runtime(tile_size) * Comptime::runtime(block_size_k) + dot_index;
let rhs_pos =
(dot_index * Comptime::runtime(block_size_n) + unit_col) / Comptime::runtime(tile_size);
// TODO this would greatly benefit from comptime arithmetic so we could unroll
let num_compute = Comptime::runtime(block_size_k) / Comptime::runtime(tile_size);
let lhs_pos_base = unit_row / Comptime::runtime(tile_size) * Comptime::runtime(block_size_k);
let rhs_pos_base = unit_col / Comptime::runtime(tile_size) * Comptime::runtime(block_size_k);
// Get a tile
for i in range(0u32, Comptime::get(tile_size), unroll) {
let WHAT = UInt::new(0); // TODO of course
register_m[i] = shared_lhs[lhs_pos + i * WHAT];
register_n[i] = shared_rhs[rhs_pos + i * WHAT];
}
for dot_index in range(0u32, num_compute, Comptime::new(false)) {
let dot_index = Comptime::runtime(tile_size) * dot_index;
let lhs_pos = lhs_pos_base + dot_index;
let rhs_pos = rhs_pos_base + dot_index;
// Replaceable with tensor core call
register_m = shared_lhs[lhs_pos];
register_n = shared_rhs[rhs_pos];
// Naive version that decomposes vectorization
for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
let row = register_m[res_idx_m];
let pos_m = res_idx_m * Comptime::runtime(tile_size);
let res_pos_base = res_idx_m * Comptime::runtime(tile_size);
for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) {
let col = register_n[res_idx_n];
results[pos_m + res_idx_n] += row * col;
let mul = register_m[res_idx_m] * register_n[res_idx_n];
results[res_pos_base + res_idx_n] += mul;
}
}
}
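Each iteration of the rewritten loop loads one vectorized fragment per operand and accumulates their outer product into the flattened result tile. A scalar sketch of the inner accumulation, assuming `tile_size = 4`:

```rust
// Scalar sketch (tile_size = 4 assumed): one k-step accumulates the outer
// product of the two fragments into the 4x4 result tile, flattened
// row-major exactly like `results[res_idx_m * tile_size + res_idx_n]`.
fn tile_outer_product(results: &mut [f32; 16], register_m: [f32; 4], register_n: [f32; 4]) {
    for m in 0..4 {
        for n in 0..4 {
            results[m * 4 + n] += register_m[m] * register_n[n];
        }
    }
}
```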
@@ -422,23 +474,28 @@ fn write_to_output<F: Float>(
let tile_size = Comptime::vectorization(kernel_state.lhs);
for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
let row_index = row + res_idx_m * out_stride_row;
let col_index = col; // Not sure
- let results_pos_m = res_idx_m * Comptime::runtime(tile_size); // Times tile_size, or no because vectorized?
+ let results_pos_m = res_idx_m * Comptime::runtime(tile_size);
if Comptime::get(check_m_bounds) {
// // TODO: Not sure if necessary. SM already padded if overflowing
// if Comptime::get(check_n_bounds) {
// within_output = within_output && col_index < kernel_state.dim_n;
// }
if row_index < kernel_state.dim_m {
// Warning: can't do the following:
// let mut out = kernel_state.out;
kernel_state.out[row_index + col_index] = results[results_pos_m];
}
} else {
kernel_state.out[row_index + col_index + offset_output] = results[results_pos_m];
// TODO just reinterpret the array if possible
let mut array = Array::<F>::new(Comptime::get(tile_size));
for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) {
array[res_idx_n] = results[results_pos_m + res_idx_n];
}
let row_index = (row + res_idx_m) * out_stride_row;
let col_index = col * out_stride_col;
// FOR DEBUGGING
// TODO: it's a pain to put a debug value in output if it's vectorized
let print_value = out_stride_row;
let mut out = Array::<F>::new(Comptime::get(tile_size));
// for i in range(0u32, Comptime::get(tile_size), unroll) {
out[3] = F::cast_from(print_value) + F::new(10.);
// }
kernel_state.out[row_index + col_index + offset_output] = out.to_vectorized(tile_size);
// kernel_state.out[res_idx_m] = out.to_vectorized(tile_size);
// F::vectorized(2., Comptime::get(tile_size));
}
}
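Setting aside the hard-coded debug value written into `out[3]`, the pattern being set up here appears to be: gather one row of the flat results buffer into a temporary array, revectorize it, and store it with a single write. A scalar sketch under that assumption (names are illustrative):

```rust
// Sketch of the assumed intended write path (tile_size = 4), ignoring the
// debug value: gather one result row and store it as one vectorized write.
fn write_result_row(out: &mut [[f32; 4]], out_index: usize, results: &[f32; 16], res_idx_m: usize) {
    let mut row = [0.0f32; 4];
    for n in 0..4 {
        row[n] = results[res_idx_m * 4 + n]; // mirrors results[results_pos_m + res_idx_n]
    }
    out[out_index] = row; // stands in for `array.to_vectorized(tile_size)` + store
}
```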
@@ -493,10 +550,11 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
let cube_count = tiling2d_launch_options(&out.shape, config.clone());
+ let vectorization_factor = 4;
let settings = KernelSettings::default()
- .vectorize_input(0, 4)
- .vectorize_input(1, 4)
- .vectorize_output(0, 4);
+ .vectorize_input(0, vectorization_factor)
+ .vectorize_input(1, vectorization_factor)
+ .vectorize_output(0, vectorization_factor);
tiling2d_matmul_kernel_launch::<E::CubeElement, R>(
client,