mirror of https://github.com/tracel-ai/burn.git
debugging wip
This commit is contained in:
parent
53cb89bdcb
commit
cebd4d4284
|
@ -143,6 +143,16 @@ pub(crate) fn codegen_call(
|
|||
}
|
||||
}
|
||||
}
|
||||
"zip" => {
|
||||
let args = &call.args;
|
||||
|
||||
// Codegen
|
||||
quote::quote! {
|
||||
{
|
||||
Comptime::zip_expand(#args)
|
||||
}
|
||||
}
|
||||
}
|
||||
"unwrap_or_else" => {
|
||||
let (expansion, variables) = codegen_args(&call.args, loop_level, variable_tracker);
|
||||
|
||||
|
|
|
@ -35,6 +35,14 @@ impl<T> Comptime<T> {
|
|||
pub fn map_expand<R, F: Fn(T) -> R>(inner: T, closure: F) -> R {
|
||||
closure(inner)
|
||||
}
|
||||
|
||||
pub fn zip<R, F: Fn(T, T) -> R>(_comptime: Self, _comptime2: Self, _closure: F) -> Comptime<R> {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
pub fn zip_expand<R, F: Fn(T, T) -> R>(inner1: T, inner2: T, closure: F) -> R {
|
||||
closure(inner1, inner2)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CubeType + Into<T::ExpandType>> Comptime<Option<T>> {
|
||||
|
|
|
@ -36,7 +36,7 @@ impl CubeTiling2dConfig {
|
|||
fn new(config: Tiling2dConfig, m: usize, k: usize, n: usize) -> Self {
|
||||
let tile_size = config.tile_size_m;
|
||||
let sm_size_lhs = config.block_size_m * config.block_size_k * tile_size;
|
||||
let sm_size_rhs = config.block_size_k * config.block_size_n * tile_size;
|
||||
let sm_size_rhs = config.block_size_n * config.block_size_k * tile_size;
|
||||
|
||||
CubeTiling2dConfig {
|
||||
block_size_m: UInt::new(config.block_size_m as u32),
|
||||
|
@ -71,8 +71,8 @@ struct Tiling2dState<F: Float> {
|
|||
pub unit_row: UInt,
|
||||
pub shared_lhs: SharedMemory<F>,
|
||||
pub shared_rhs: SharedMemory<F>,
|
||||
pub register_m: Array<F>,
|
||||
pub register_n: Array<F>,
|
||||
pub register_m: F,
|
||||
pub register_n: F,
|
||||
pub results: Array<F>,
|
||||
pub lhs_stride_col: UInt,
|
||||
pub lhs_stride_row: UInt,
|
||||
|
@ -135,7 +135,7 @@ fn gather_kernel_information<F: Float>(
|
|||
let col = skip_col + unit_col;
|
||||
|
||||
// Batch offset for output
|
||||
let offset_output = dim_m * dim_n * batch;
|
||||
let offset_output = dim_m * dim_n * batch / Comptime::runtime(tile_size);
|
||||
|
||||
// Calculate offset for lhs and rhs, without regards to batches
|
||||
let mut offset_lhs = skip_row * lhs_stride_row;
|
||||
|
@ -148,9 +148,11 @@ fn gather_kernel_information<F: Float>(
|
|||
offset_rhs += tmp % rhs.shape(b) * rhs.stride(b);
|
||||
}
|
||||
|
||||
let register_m = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
|
||||
let register_n = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
|
||||
let results = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
|
||||
offset_lhs /= Comptime::runtime(tile_size);
|
||||
offset_rhs /= Comptime::runtime(tile_size);
|
||||
|
||||
let tile_squared = Comptime::zip(tile_size, tile_size, |c1, c2| UInt::new(c1.val * c2.val));
|
||||
let results = Array::<F>::new(Comptime::get(tile_squared));
|
||||
|
||||
let shared_lhs =
|
||||
SharedMemory::<F>::vectorized(Comptime::get(sm_size_lhs), Comptime::get(tile_size));
|
||||
|
@ -167,8 +169,10 @@ fn gather_kernel_information<F: Float>(
|
|||
n_loops = dim_k / Comptime::runtime(block_size_k);
|
||||
}
|
||||
|
||||
// Dummy declaration
|
||||
// Dummy declarations
|
||||
let k = UInt::new(0);
|
||||
let register_m = F::vectorized(0., Comptime::get(tile_size));
|
||||
let register_n = F::vectorized(0., Comptime::get(tile_size));
|
||||
|
||||
Tiling2dState {
|
||||
n_loops,
|
||||
|
@ -205,7 +209,7 @@ fn load_shared_memory<F: Float>(
|
|||
kernel_state: Tiling2dState<F>,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
) {
|
||||
load_lhs_tensor(
|
||||
load_lhs_tensor_plain(
|
||||
kernel_state.lhs,
|
||||
kernel_state.offset_lhs,
|
||||
kernel_state.shared_lhs,
|
||||
|
@ -215,11 +219,12 @@ fn load_shared_memory<F: Float>(
|
|||
kernel_state.lhs_stride_row,
|
||||
kernel_state.dim_m,
|
||||
kernel_state.row,
|
||||
kernel_state.col,
|
||||
kernel_state.k,
|
||||
kernel_state.dim_k,
|
||||
config,
|
||||
);
|
||||
load_rhs_tensor(
|
||||
load_rhs_tensor_transposed(
|
||||
kernel_state.rhs,
|
||||
kernel_state.offset_rhs,
|
||||
kernel_state.shared_rhs,
|
||||
|
@ -228,6 +233,7 @@ fn load_shared_memory<F: Float>(
|
|||
kernel_state.rhs_stride_row,
|
||||
kernel_state.rhs_stride_col,
|
||||
kernel_state.dim_n,
|
||||
kernel_state.row,
|
||||
kernel_state.col,
|
||||
kernel_state.k,
|
||||
kernel_state.dim_k,
|
||||
|
@ -237,7 +243,7 @@ fn load_shared_memory<F: Float>(
|
|||
|
||||
#[cube]
|
||||
/// Assumes vectorization is in the same orientation we need in shared memory
|
||||
fn load_lhs_tensor<F: Float>(
|
||||
fn load_lhs_tensor_plain<F: Float>(
|
||||
lhs: Tensor<F>,
|
||||
offset_lhs: UInt,
|
||||
mut shared_lhs: SharedMemory<F>,
|
||||
|
@ -247,6 +253,7 @@ fn load_lhs_tensor<F: Float>(
|
|||
lhs_stride_row: UInt,
|
||||
dim_m: UInt,
|
||||
row: UInt,
|
||||
col: UInt,
|
||||
k: UInt,
|
||||
dim_k: UInt,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
|
@ -257,43 +264,74 @@ fn load_lhs_tensor<F: Float>(
|
|||
let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds);
|
||||
let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds);
|
||||
|
||||
let position_in_lhs_base = (k + unit_col) * lhs_stride_col
|
||||
+ unit_row * (lhs_stride_row / Comptime::runtime(tile_size))
|
||||
+ offset_lhs;
|
||||
let sm_position_base =
|
||||
(unit_row * Comptime::runtime(block_size_k) + unit_col) / Comptime::runtime(tile_size);
|
||||
let sm_stride = Comptime::runtime(block_size_k) / Comptime::runtime(tile_size);
|
||||
|
||||
if Comptime::get(check_m_bounds) {
|
||||
let n_writes = UInt::min(dim_m - row, Comptime::runtime(tile_size));
|
||||
|
||||
for i in range(0u32, n_writes, Comptime::new(false)) {
|
||||
let current_row = unit_row + i;
|
||||
|
||||
if Comptime::get(check_k_bounds) {
|
||||
if k + unit_col >= dim_k {
|
||||
// Shouldn't be partial vec, or it would not have accepted this vectorization factor
|
||||
return;
|
||||
if Comptime::get(check_k_bounds) {
|
||||
if col >= dim_k {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let sm_position = sm_position_base + i * sm_stride;
|
||||
shared_lhs[sm_position] = F::vectorized(0., Comptime::get(tile_size));
|
||||
}
|
||||
} else {
|
||||
let num_reads = UInt::min(dim_m - row, Comptime::runtime(tile_size));
|
||||
for i in range(0u32, num_reads, Comptime::new(false)) {
|
||||
let sm_position = sm_position_base + i * sm_stride;
|
||||
let position_in_lhs =
|
||||
position_in_lhs_base + i * (lhs_stride_row / Comptime::runtime(tile_size));
|
||||
shared_lhs[sm_position] = lhs[position_in_lhs];
|
||||
}
|
||||
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
|
||||
let sm_position = sm_position_base + i * sm_stride;
|
||||
shared_lhs[sm_position] = F::vectorized(0., Comptime::get(tile_size));
|
||||
}
|
||||
}
|
||||
|
||||
// todo: is this runtime if mandatory?
|
||||
if current_col < Comptime::runtime(block_size_k) {
|
||||
// todo: runtime consts could be precomputed
|
||||
let sm_position = current_row
|
||||
* (Comptime::runtime(block_size_k) / Comptime::runtime(tile_size))
|
||||
+ (unit_col / Comptime::runtime(tile_size));
|
||||
|
||||
// lhs_stride_col should be 1
|
||||
let position_in_lhs = (k + unit_col) * lhs_stride_col
|
||||
+ current_row * (lhs_stride_row / Comptime::runtime(tile_size))
|
||||
+ offset_lhs;
|
||||
|
||||
} else {
|
||||
let num_reads = UInt::min(dim_m - row, Comptime::runtime(tile_size));
|
||||
for i in range(0u32, num_reads, Comptime::new(false)) {
|
||||
let sm_position = sm_position_base + i * sm_stride;
|
||||
let position_in_lhs =
|
||||
position_in_lhs_base + i * (lhs_stride_row / Comptime::runtime(tile_size));
|
||||
shared_lhs[sm_position] = lhs[position_in_lhs];
|
||||
}
|
||||
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
|
||||
let sm_position = sm_position_base + i * sm_stride;
|
||||
shared_lhs[sm_position] = F::vectorized(0., Comptime::get(tile_size));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
// TODO Same for loop inner as above
|
||||
if Comptime::get(check_k_bounds) {
|
||||
if col >= dim_k {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let sm_position = sm_position_base + i * sm_stride;
|
||||
shared_lhs[sm_position] = F::vectorized(0., Comptime::get(tile_size));
|
||||
}
|
||||
} else {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let sm_position = sm_position_base + i * sm_stride;
|
||||
let position_in_lhs =
|
||||
position_in_lhs_base + i * (lhs_stride_row / Comptime::runtime(tile_size));
|
||||
shared_lhs[sm_position] = lhs[position_in_lhs];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let sm_position = sm_position_base + i * sm_stride;
|
||||
let position_in_lhs =
|
||||
position_in_lhs_base + i * (lhs_stride_row / Comptime::runtime(tile_size));
|
||||
shared_lhs[sm_position] = lhs[position_in_lhs];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn load_rhs_tensor<F: Float>(
|
||||
fn load_rhs_tensor_transposed<F: Float>(
|
||||
rhs: Tensor<F>,
|
||||
offset_rhs: UInt,
|
||||
mut shared_rhs: SharedMemory<F>,
|
||||
|
@ -302,14 +340,12 @@ fn load_rhs_tensor<F: Float>(
|
|||
rhs_stride_row: UInt,
|
||||
rhs_stride_col: UInt,
|
||||
dim_n: UInt,
|
||||
row: UInt,
|
||||
col: UInt,
|
||||
k: UInt,
|
||||
dim_k: UInt,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
) {
|
||||
// TODO :
|
||||
// read n element-wise with stride, then store as vectorized-n in sm
|
||||
|
||||
let block_size_k = Comptime::map(config, |c| c.block_size_k);
|
||||
let block_size_n = Comptime::map(config, |c| c.block_size_n);
|
||||
let unroll = Comptime::map(config, |c| c.unroll);
|
||||
|
@ -317,49 +353,66 @@ fn load_rhs_tensor<F: Float>(
|
|||
let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds);
|
||||
let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds);
|
||||
|
||||
if Comptime::get(check_n_bounds) {
|
||||
let n_writes_n = UInt::min(dim_n - col, Comptime::runtime(tile_size));
|
||||
for j in range(0u32, n_writes, Comptime::new(false)) {
|
||||
let n_writes_k = UInt::min(dim_k - row, Comptime::runtime(tile_size));
|
||||
let mut array = Array::<F>::new(Comptime::get(tile_size));
|
||||
let position_base_no_vec =
|
||||
(k + current_row) * rhs_stride_row + unit_col * rhs_stride_col + offset_rhs;
|
||||
let position_base = (k + unit_row) * rhs_stride_row + unit_col * rhs_stride_col + offset_rhs;
|
||||
let sm_position_base =
|
||||
(unit_col * Comptime::runtime(block_size_k) + unit_row) / Comptime::runtime(tile_size);
|
||||
let sm_stride = Comptime::runtime(block_size_k) / Comptime::runtime(tile_size);
|
||||
|
||||
for i in range(0u32, n_writes_k, Comptime::new(false)) {
|
||||
let current_row = unit_row + i;
|
||||
let current_col = unit_col + j;
|
||||
|
||||
// TODO is this necessary?
|
||||
if current_row < Comptime::runtime(block_size_k) {
|
||||
// todo: runtime consts could be precomputed
|
||||
let sm_position = current_row
|
||||
* (Comptime::runtime(block_size_n) / Comptime::runtime(tile_size))
|
||||
+ (unit_col / Comptime::runtime(tile_size));
|
||||
|
||||
// Unvectorize
|
||||
// TODO: Should increment second [] if stride_2 is 1
|
||||
// Plus, other than 0s are unaccessible
|
||||
array[i] = rhs[position_base + i * rhs_stride_col][j];
|
||||
// Read entries
|
||||
let mut entries = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
|
||||
if Comptime::get(check_k_bounds) {
|
||||
if Comptime::get(check_n_bounds) {
|
||||
// We assume whole vectorization is out of bound
|
||||
if col >= dim_n {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
entries[i] = F::vectorized(0., Comptime::get(tile_size));
|
||||
}
|
||||
// Pad with zeros
|
||||
if Comptime::get(check_k_bounds) {
|
||||
for i in range(n_writes, Comptime::get(tile_size), Comptime::new(false)) {
|
||||
array[i] = F::new(0.);
|
||||
}
|
||||
} else {
|
||||
let num_reads = UInt::min(dim_k - row, Comptime::runtime(tile_size));
|
||||
for i in range(0u32, num_reads, Comptime::new(false)) {
|
||||
entries[i] = rhs[position_base + i * rhs_stride_row];
|
||||
}
|
||||
|
||||
// TODO could tile_size be fetched from array length?
|
||||
// TODO make sure what we write works with what is now read in computation loop
|
||||
shared_rhs[sm_position] = array.to_vectorized(tile_size);
|
||||
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
|
||||
entries[i] = F::vectorized(0., Comptime::get(tile_size));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let num_reads = UInt::min(dim_k - row, Comptime::runtime(tile_size));
|
||||
for i in range(0u32, num_reads, Comptime::new(false)) {
|
||||
entries[i] = rhs[position_base + i * rhs_stride_row];
|
||||
}
|
||||
for i in range(num_reads, Comptime::get(tile_size), Comptime::new(false)) {
|
||||
entries[i] = F::vectorized(0., Comptime::get(tile_size));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
for j in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
// TODO Same for loop inner as above
|
||||
if Comptime::get(check_n_bounds) {
|
||||
// We assume whole vectorization is out of bound
|
||||
if col >= dim_n {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
entries[i] = F::vectorized(0., Comptime::get(tile_size));
|
||||
}
|
||||
} else {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
entries[i] = rhs[position_base + i * rhs_stride_row];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
entries[i] = rhs[position_base + i * rhs_stride_row];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decompose vectorization then recompose as transposed
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let mut transposed = Array::<F>::new(Comptime::get(tile_size));
|
||||
for j in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
transposed[j] = entries[j][i];
|
||||
}
|
||||
let sm_position = sm_position_base + i * sm_stride;
|
||||
shared_rhs[sm_position] = transposed.to_vectorized(tile_size);
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
|
@ -379,26 +432,25 @@ fn computation_loop<F: Float>(
|
|||
let unroll = Comptime::map(config, |c| c.unroll);
|
||||
let tile_size = Comptime::vectorization(kernel_state.lhs);
|
||||
|
||||
for dot_index in range(0u32, Comptime::get(block_size_k), unroll) {
|
||||
let lhs_pos =
|
||||
unit_row / Comptime::runtime(tile_size) * Comptime::runtime(block_size_k) + dot_index;
|
||||
let rhs_pos =
|
||||
(dot_index * Comptime::runtime(block_size_n) + unit_col) / Comptime::runtime(tile_size);
|
||||
// TODO this would greatly beneficiate from comptime arithmetic so we could unroll
|
||||
let num_compute = Comptime::runtime(block_size_k) / Comptime::runtime(tile_size);
|
||||
let lhs_pos_base = unit_row / Comptime::runtime(tile_size) * Comptime::runtime(block_size_k);
|
||||
let rhs_pos_base = unit_col / Comptime::runtime(tile_size) * Comptime::runtime(block_size_k);
|
||||
|
||||
// Get a tile
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let WHAT = UInt::new(0); // TODO of course
|
||||
register_m[i] = shared_lhs[lhs_pos + i * WHAT];
|
||||
register_n[i] = shared_rhs[rhs_pos + i * WHAT];
|
||||
}
|
||||
for dot_index in range(0u32, num_compute, Comptime::new(false)) {
|
||||
let dot_index = Comptime::runtime(tile_size) * dot_index;
|
||||
let lhs_pos = lhs_pos_base + dot_index;
|
||||
let rhs_pos = rhs_pos_base + dot_index;
|
||||
|
||||
// Replaceable with tensor core call
|
||||
register_m = shared_lhs[lhs_pos];
|
||||
register_n = shared_rhs[rhs_pos];
|
||||
|
||||
// Naive version that decomposes vectorization
|
||||
for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let row = register_m[res_idx_m];
|
||||
let pos_m = res_idx_m * Comptime::runtime(tile_size);
|
||||
let res_pos_base = res_idx_m * Comptime::runtime(tile_size);
|
||||
for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let col = register_n[res_idx_n];
|
||||
results[pos_m + res_idx_n] += row * col;
|
||||
let mul = register_m[res_idx_m] * register_n[res_idx_n];
|
||||
results[res_pos_base + res_idx_n] += mul;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -422,23 +474,28 @@ fn write_to_output<F: Float>(
|
|||
let tile_size = Comptime::vectorization(kernel_state.lhs);
|
||||
|
||||
for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let row_index = row + res_idx_m * out_stride_row;
|
||||
let col_index = col; // Not sure
|
||||
let results_pos_m = res_idx_m * Comptime::runtime(tile_size); // Times tile_size, or no because vectorized?
|
||||
let results_pos_m = res_idx_m * Comptime::runtime(tile_size);
|
||||
|
||||
if Comptime::get(check_m_bounds) {
|
||||
// // TODO: Not sure if necessary. SM already padded if overflowing
|
||||
// if Comptime::get(check_n_bounds) {
|
||||
// within_output = within_output && col_index < kernel_state.dim_n;
|
||||
// }
|
||||
if row_index < kernel_state.dim_m {
|
||||
// Warning: can't do the following:
|
||||
// let mut out = kernel_state.out;
|
||||
kernel_state.out[row_index + col_index] = results[results_pos_m];
|
||||
}
|
||||
} else {
|
||||
kernel_state.out[row_index + col_index + offset_output] = results[results_pos_m];
|
||||
// TODO just reinterpret the array if possible
|
||||
let mut array = Array::<F>::new(Comptime::get(tile_size));
|
||||
for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
array[res_idx_n] = results[results_pos_m + res_idx_n];
|
||||
}
|
||||
|
||||
let row_index = (row + res_idx_m) * out_stride_row;
|
||||
let col_index = col * out_stride_col;
|
||||
|
||||
// FOR DEBUGGING
|
||||
// TODO: it's a pain to put a debug value in output if it's vectorized
|
||||
let print_value = out_stride_row;
|
||||
let mut out = Array::<F>::new(Comptime::get(tile_size));
|
||||
// for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
out[3] = F::cast_from(print_value) + F::new(10.);
|
||||
// }
|
||||
|
||||
kernel_state.out[row_index + col_index + offset_output] = out.to_vectorized(tile_size);
|
||||
// kernel_state.out[res_idx_m] = out.to_vectorized(tile_size);
|
||||
// F::vectorized(2., Comptime::get(tile_size));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -493,10 +550,11 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
|
|||
|
||||
let cube_count = tiling2d_launch_options(&out.shape, config.clone());
|
||||
|
||||
let vectorization_factor = 4;
|
||||
let settings = KernelSettings::default()
|
||||
.vectorize_input(0, 4)
|
||||
.vectorize_input(1, 4)
|
||||
.vectorize_output(0, 4);
|
||||
.vectorize_input(0, vectorization_factor)
|
||||
.vectorize_input(1, vectorization_factor)
|
||||
.vectorize_output(0, vectorization_factor);
|
||||
|
||||
tiling2d_matmul_kernel_launch::<E::CubeElement, R>(
|
||||
client,
|
||||
|
|
Loading…
Reference in New Issue