runs, but oh lord is it wrong

This commit is contained in:
louisfd 2024-06-17 15:00:17 -04:00
parent 43cf06da18
commit 12843ed159
3 changed files with 110 additions and 31 deletions

View File

@ -1,5 +1,8 @@
use crate::frontend::{CubeElem, CubeType, ExpandElement};
use crate::ir::Elem;
use crate::ir::{Elem, Operator};
use crate::prelude::{init_expand, CubeContext};
use super::Init;
impl CubeType for bool {
type ExpandType = ExpandElement;
@ -10,3 +13,9 @@ impl CubeElem for bool {
Elem::Bool
}
}
// impl Init for bool {
// fn init(self, context: &mut CubeContext) -> Self {
// init_expand(context, self, Operator::Assign)
// }
// }

View File

@ -1,7 +1,7 @@
use crate::{fusion::kernel, kernel::into_contiguous, tensor::JitTensor, FloatElement, JitRuntime};
use burn_cube::prelude::*;
use super::{tiling2d_launch_options, tiling2d_shader::write_to_output, Tiling2dConfig};
use super::{tiling2d_launch_options, Tiling2dConfig};
impl Init for CubeTiling2dConfig {
fn init(self, _context: &mut CubeContext) -> Self {
@ -20,8 +20,12 @@ pub struct CubeTiling2dConfig {
pub block_size_n: UInt,
/// Loop unrolling
pub unroll: bool,
/// Bounds must be checked on lhs dimension
pub check_m_bounds: bool,
/// Bounds must be checked on common dimension
pub check_k_bounds: bool,
/// Bounds must be checked on rhs dimension
pub check_n_bounds: bool,
/// Shared memory size lhs: technically derivable from others, but needs comptime arithmetic
pub sm_size_lhs: UInt,
/// Shared memory size rhs: technically derivable from others, but needs comptime arithmetic
@ -29,7 +33,7 @@ pub struct CubeTiling2dConfig {
}
impl CubeTiling2dConfig {
fn new(config: Tiling2dConfig, k: usize) -> Self {
fn new(config: Tiling2dConfig, m: usize, k: usize, n: usize) -> Self {
let tile_size = config.tile_size_m;
let sm_size_lhs = config.block_size_m * config.block_size_k * tile_size;
let sm_size_rhs = config.block_size_k * config.block_size_n * tile_size;
@ -39,14 +43,16 @@ impl CubeTiling2dConfig {
block_size_k: UInt::new(config.block_size_k as u32),
block_size_n: UInt::new(config.block_size_n as u32),
unroll: config.unroll,
check_m_bounds: m % config.block_size_m != 0,
check_k_bounds: k % config.block_size_k != 0,
check_n_bounds: n % config.block_size_n != 0,
sm_size_lhs: UInt::new(sm_size_lhs as u32),
sm_size_rhs: UInt::new(sm_size_rhs as u32),
}
}
}
#[derive(CubeType)]
#[derive(CubeType, Copy, Clone)]
struct Tiling2dState<F: Float> {
pub n_loops: UInt,
pub k: UInt,
@ -65,8 +71,8 @@ struct Tiling2dState<F: Float> {
pub unit_row: UInt,
pub shared_lhs: SharedMemory<F>,
pub shared_rhs: SharedMemory<F>,
pub register_m: F,
pub register_n: F,
pub register_m: Array<F>,
pub register_n: Array<F>,
pub results: Array<F>,
pub lhs_stride_col: UInt,
pub lhs_stride_row: UInt,
@ -136,12 +142,14 @@ fn gather_kernel_information<F: Float>(
let mut offset_rhs = skip_col * rhs_stride_col;
// Batch offset for lhs, rhs
for b in range(0, rank - UInt::new(2), unroll) {
for b in range(0u32, rank - UInt::new(2), unroll) {
let tmp = offset_output / out.stride(b);
offset_lhs += tmp % lhs.shape(b) * lhs.stride(b);
offset_rhs += tmp % rhs.shape(b) * rhs.stride(b);
}
let register_m = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
let register_n = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
let results = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
let shared_lhs =
@ -159,10 +167,8 @@ fn gather_kernel_information<F: Float>(
n_loops = dim_k / Comptime::runtime(block_size_k);
}
// Dummy declarations
// Dummy declaration
let k = UInt::new(0);
let register_m = F::new(0.);
let register_n = F::new(0.);
Tiling2dState {
n_loops,
@ -253,14 +259,17 @@ fn load_tensor_with_checks<F: Float>(
let tile_size = Comptime::vectorization(input);
let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds);
let mut n_writes = Comptime::runtime(tile_size);
// TODO: direct assignation from Comptime::runtime gives a constant
let n_writes_tmp = Comptime::runtime(tile_size);
let mut n_writes = n_writes_tmp;
if Comptime::get(check_k_bounds) {
n_writes = UInt::min(dim - pos_in_dim, Comptime::runtime(tile_size));
}
// TODO we should avoid that if in no_check_bound version
if n_writes >= UInt::new(1) {
for j in range(0, Comptime::get(tile_size), unroll) {
for j in range(0u32, Comptime::get(tile_size), unroll) {
let current = unit_idx_1 + j;
if current + k < dim_k {
@ -281,8 +290,11 @@ fn load_tensor_with_checks<F: Float>(
// TODO simplify when stride_2 is 1, so we can leverage already vectorized
let mut array = Array::<F>::new(Comptime::get(tile_size));
for i in range(0, n_writes, Comptime::new(false)) {
array[i] = input[position_base + i * stride_2];
for i in range(0u32, n_writes, Comptime::new(false)) {
// Unvectorize
// TODO: Should increment second [] if stride_2 is 1
// Plus, other than 0s are unaccessible
array[i] = input[position_base + i * stride_2][UInt::new(0)];
}
// Pad with zeros
if Comptime::get(check_k_bounds) {
@ -292,6 +304,7 @@ fn load_tensor_with_checks<F: Float>(
}
// TODO could tile_size be fetched from array length?
// TODO make sure what we write works with what is now read in computation loop
shared_memory[sm_position] = array.to_vectorized(tile_size);
}
}
@ -308,21 +321,77 @@ fn computation_loop<F: Float>(
let shared_lhs = kernel_state.shared_lhs;
let shared_rhs = kernel_state.shared_rhs;
let mut register_m = kernel_state.register_m;
let register_n = kernel_state.register_n;
let results = kernel_state.results;
let mut register_n = kernel_state.register_n;
let mut results = kernel_state.results;
let block_size_k = Comptime::map(config, |c| c.block_size_k);
let block_size_n = Comptime::map(config, |c| c.block_size_n);
let unroll = Comptime::map(config, |c| c.unroll);
let vectorization = Comptime::vectorization(kernel_state.lhs);
let tile_size = Comptime::vectorization(kernel_state.lhs);
for dot_index in range(0, Comptime::get(block_size_k), unroll) {
register_m = shared_lhs[unit_row / Comptime::runtime(vectorization)
* Comptime::runtime(block_size_k)
+ dot_index];
for dot_index in range(0u32, Comptime::get(block_size_k), unroll) {
let lhs_pos =
unit_row / Comptime::runtime(tile_size) * Comptime::runtime(block_size_k) + dot_index;
let rhs_pos =
(dot_index * Comptime::runtime(block_size_n) + unit_col) / Comptime::runtime(tile_size);
// Get a tile
for i in range(0u32, Comptime::get(tile_size), unroll) {
let WHAT = UInt::new(0); // TODO of course
register_m[i] = shared_lhs[lhs_pos + i * WHAT];
register_n[i] = shared_rhs[rhs_pos + i * WHAT];
}
// Replaceable with tensor core call
for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
let row = register_m[res_idx_m];
let pos_m = res_idx_m * Comptime::runtime(tile_size);
for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) {
let col = register_n[res_idx_n];
results[pos_m + res_idx_n] += row * col;
}
}
}
}
#[cube]
fn write_to_output<F: Float>(
mut kernel_state: Tiling2dState<F>,
config: Comptime<CubeTiling2dConfig>,
) {
// No bounds check version
let row = kernel_state.row;
let col = kernel_state.col;
let out_stride_row = kernel_state.out_stride_row;
let out_stride_col = kernel_state.out_stride_col;
let results = kernel_state.results;
let unroll = Comptime::map(config, |c| c.unroll);
let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds);
let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds);
let tile_size = Comptime::vectorization(kernel_state.lhs);
for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
let row_index = row + res_idx_m * out_stride_row;
let col_index = col; // Not sure
let results_pos_m = res_idx_m * Comptime::runtime(tile_size); // Times tile_size, or no because vectorized?
if Comptime::get(check_m_bounds) {
// // TODO: Not sure if necessary. SM already padded if overflowing
// if Comptime::get(check_n_bounds) {
// within_output = within_output && col_index < kernel_state.dim_n;
// }
if row_index < kernel_state.dim_m {
// Warning: can't do the following:
// let mut out = kernel_state.out;
kernel_state.out[row_index + col_index] = results[results_pos_m];
}
} else {
kernel_state.out[row_index + col_index] = results[results_pos_m];
}
}
}
#[cube(launch)]
/// Kernel for tiling2d matmul
pub fn tiling2d_matmul_kernel<F: Float>(
lhs: Tensor<F>,
rhs: Tensor<F>,
@ -330,21 +399,21 @@ pub fn tiling2d_matmul_kernel<F: Float>(
config: Comptime<CubeTiling2dConfig>,
) {
let block_size_k = Comptime::map(config, |c| c.block_size_k);
let kernel_state = gather_kernel_information::<F>(lhs, rhs, out, config);
let mut kernel_state = gather_kernel_information::<F>(lhs, rhs, out, config);
for i in range(0u32, kernel_state.n_loops, Comptime::new(false)) {
let k = i * Comptime::runtime(block_size_k);
kernel_state.k = i * Comptime::runtime(block_size_k);
load_shared_memory(kernel_state, config);
sync_units();
computation_loop(kernel_state);
computation_loop(kernel_state, config);
sync_units();
}
write_to_output(kernel_state);
write_to_output(kernel_state, config);
}
/// Matrix multiplication using tiling 2d algorithm with
@ -355,9 +424,9 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
out: JitTensor<R, E, D>,
config: Tiling2dConfig,
) -> JitTensor<R, E, D> {
// Bound checks can be done comptime specifically for all dims
// let bounds_check_required = check_bound_requirement(&lhs.shape, &rhs.shape, &config);
let m = lhs.shape.dims[D - 2];
let k = lhs.shape.dims[D - 1];
let n = rhs.shape.dims[D - 1];
let client = lhs.client.clone();
@ -373,9 +442,9 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
let cube_count = tiling2d_launch_options(&out.shape, config.clone());
let settings = KernelSettings::default()
.vectorize_input(0, 1)
.vectorize_input(1, 1)
.vectorize_output(0, 1);
.vectorize_input(0, 4)
.vectorize_input(1, 4)
.vectorize_output(0, 4);
tiling2d_matmul_kernel_launch::<E::CubeElement, R>(
client,
@ -384,7 +453,7 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
TensorHandle::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
TensorHandle::new(&out.handle, &out.strides, &out.shape.dims),
CubeTiling2dConfig::new(config, k),
CubeTiling2dConfig::new(config, m, k, n),
);
out

View File

@ -88,6 +88,7 @@ where
}
let compile = kernel.compile();
println!("{}", compile.source);
let pipeline = self.compile_source(&compile.source);
self.pipelines.insert(kernel_id.clone(), pipeline.clone());