mirror of https://github.com/tracel-ai/burn.git
runs, but oh lord is it wrong
This commit is contained in:
parent
43cf06da18
commit
12843ed159
|
@ -1,5 +1,8 @@
|
|||
use crate::frontend::{CubeElem, CubeType, ExpandElement};
|
||||
use crate::ir::Elem;
|
||||
use crate::ir::{Elem, Operator};
|
||||
use crate::prelude::{init_expand, CubeContext};
|
||||
|
||||
use super::Init;
|
||||
|
||||
impl CubeType for bool {
|
||||
type ExpandType = ExpandElement;
|
||||
|
@ -10,3 +13,9 @@ impl CubeElem for bool {
|
|||
Elem::Bool
|
||||
}
|
||||
}
|
||||
|
||||
// impl Init for bool {
|
||||
// fn init(self, context: &mut CubeContext) -> Self {
|
||||
// init_expand(context, self, Operator::Assign)
|
||||
// }
|
||||
// }
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{fusion::kernel, kernel::into_contiguous, tensor::JitTensor, FloatElement, JitRuntime};
|
||||
use burn_cube::prelude::*;
|
||||
|
||||
use super::{tiling2d_launch_options, tiling2d_shader::write_to_output, Tiling2dConfig};
|
||||
use super::{tiling2d_launch_options, Tiling2dConfig};
|
||||
|
||||
impl Init for CubeTiling2dConfig {
|
||||
fn init(self, _context: &mut CubeContext) -> Self {
|
||||
|
@ -20,8 +20,12 @@ pub struct CubeTiling2dConfig {
|
|||
pub block_size_n: UInt,
|
||||
/// Loop unrolling
|
||||
pub unroll: bool,
|
||||
/// Bounds must be checked on lhs dimension
|
||||
pub check_m_bounds: bool,
|
||||
/// Bounds must be checked on common dimension
|
||||
pub check_k_bounds: bool,
|
||||
/// Bounds must be checked on rhs dimension
|
||||
pub check_n_bounds: bool,
|
||||
/// Shared memory size lhs: technically derivable from others, but needs comptime arithmetic
|
||||
pub sm_size_lhs: UInt,
|
||||
/// Shared memory size rhs: technically derivable from others, but needs comptime arithmetic
|
||||
|
@ -29,7 +33,7 @@ pub struct CubeTiling2dConfig {
|
|||
}
|
||||
|
||||
impl CubeTiling2dConfig {
|
||||
fn new(config: Tiling2dConfig, k: usize) -> Self {
|
||||
fn new(config: Tiling2dConfig, m: usize, k: usize, n: usize) -> Self {
|
||||
let tile_size = config.tile_size_m;
|
||||
let sm_size_lhs = config.block_size_m * config.block_size_k * tile_size;
|
||||
let sm_size_rhs = config.block_size_k * config.block_size_n * tile_size;
|
||||
|
@ -39,14 +43,16 @@ impl CubeTiling2dConfig {
|
|||
block_size_k: UInt::new(config.block_size_k as u32),
|
||||
block_size_n: UInt::new(config.block_size_n as u32),
|
||||
unroll: config.unroll,
|
||||
check_m_bounds: m % config.block_size_m != 0,
|
||||
check_k_bounds: k % config.block_size_k != 0,
|
||||
check_n_bounds: n % config.block_size_n != 0,
|
||||
sm_size_lhs: UInt::new(sm_size_lhs as u32),
|
||||
sm_size_rhs: UInt::new(sm_size_rhs as u32),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(CubeType)]
|
||||
#[derive(CubeType, Copy, Clone)]
|
||||
struct Tiling2dState<F: Float> {
|
||||
pub n_loops: UInt,
|
||||
pub k: UInt,
|
||||
|
@ -65,8 +71,8 @@ struct Tiling2dState<F: Float> {
|
|||
pub unit_row: UInt,
|
||||
pub shared_lhs: SharedMemory<F>,
|
||||
pub shared_rhs: SharedMemory<F>,
|
||||
pub register_m: F,
|
||||
pub register_n: F,
|
||||
pub register_m: Array<F>,
|
||||
pub register_n: Array<F>,
|
||||
pub results: Array<F>,
|
||||
pub lhs_stride_col: UInt,
|
||||
pub lhs_stride_row: UInt,
|
||||
|
@ -136,12 +142,14 @@ fn gather_kernel_information<F: Float>(
|
|||
let mut offset_rhs = skip_col * rhs_stride_col;
|
||||
|
||||
// Batch offset for lhs, rhs
|
||||
for b in range(0, rank - UInt::new(2), unroll) {
|
||||
for b in range(0u32, rank - UInt::new(2), unroll) {
|
||||
let tmp = offset_output / out.stride(b);
|
||||
offset_lhs += tmp % lhs.shape(b) * lhs.stride(b);
|
||||
offset_rhs += tmp % rhs.shape(b) * rhs.stride(b);
|
||||
}
|
||||
|
||||
let register_m = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
|
||||
let register_n = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
|
||||
let results = Array::<F>::vectorized(Comptime::get(tile_size), Comptime::get(tile_size));
|
||||
|
||||
let shared_lhs =
|
||||
|
@ -159,10 +167,8 @@ fn gather_kernel_information<F: Float>(
|
|||
n_loops = dim_k / Comptime::runtime(block_size_k);
|
||||
}
|
||||
|
||||
// Dummy declarations
|
||||
// Dummy declaration
|
||||
let k = UInt::new(0);
|
||||
let register_m = F::new(0.);
|
||||
let register_n = F::new(0.);
|
||||
|
||||
Tiling2dState {
|
||||
n_loops,
|
||||
|
@ -253,14 +259,17 @@ fn load_tensor_with_checks<F: Float>(
|
|||
let tile_size = Comptime::vectorization(input);
|
||||
let check_k_bounds = Comptime::map(config, |c| c.check_k_bounds);
|
||||
|
||||
let mut n_writes = Comptime::runtime(tile_size);
|
||||
// TODO: direct assignment from Comptime::runtime gives a constant
|
||||
let n_writes_tmp = Comptime::runtime(tile_size);
|
||||
let mut n_writes = n_writes_tmp;
|
||||
|
||||
if Comptime::get(check_k_bounds) {
|
||||
n_writes = UInt::min(dim - pos_in_dim, Comptime::runtime(tile_size));
|
||||
}
|
||||
|
||||
// TODO we should avoid that if in no_check_bound version
|
||||
if n_writes >= UInt::new(1) {
|
||||
for j in range(0, Comptime::get(tile_size), unroll) {
|
||||
for j in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let current = unit_idx_1 + j;
|
||||
|
||||
if current + k < dim_k {
|
||||
|
@ -281,8 +290,11 @@ fn load_tensor_with_checks<F: Float>(
|
|||
// TODO simplify when stride_2 is 1, so we can leverage already vectorized
|
||||
let mut array = Array::<F>::new(Comptime::get(tile_size));
|
||||
|
||||
for i in range(0, n_writes, Comptime::new(false)) {
|
||||
array[i] = input[position_base + i * stride_2];
|
||||
for i in range(0u32, n_writes, Comptime::new(false)) {
|
||||
// Unvectorize
|
||||
// TODO: Should increment second [] if stride_2 is 1
|
||||
// Plus, indices other than 0 are inaccessible
|
||||
array[i] = input[position_base + i * stride_2][UInt::new(0)];
|
||||
}
|
||||
// Pad with zeros
|
||||
if Comptime::get(check_k_bounds) {
|
||||
|
@ -292,6 +304,7 @@ fn load_tensor_with_checks<F: Float>(
|
|||
}
|
||||
|
||||
// TODO could tile_size be fetched from array length?
|
||||
// TODO make sure what we write works with what is now read in computation loop
|
||||
shared_memory[sm_position] = array.to_vectorized(tile_size);
|
||||
}
|
||||
}
|
||||
|
@ -308,21 +321,77 @@ fn computation_loop<F: Float>(
|
|||
let shared_lhs = kernel_state.shared_lhs;
|
||||
let shared_rhs = kernel_state.shared_rhs;
|
||||
let mut register_m = kernel_state.register_m;
|
||||
let register_n = kernel_state.register_n;
|
||||
let results = kernel_state.results;
|
||||
let mut register_n = kernel_state.register_n;
|
||||
let mut results = kernel_state.results;
|
||||
let block_size_k = Comptime::map(config, |c| c.block_size_k);
|
||||
let block_size_n = Comptime::map(config, |c| c.block_size_n);
|
||||
let unroll = Comptime::map(config, |c| c.unroll);
|
||||
let vectorization = Comptime::vectorization(kernel_state.lhs);
|
||||
let tile_size = Comptime::vectorization(kernel_state.lhs);
|
||||
|
||||
for dot_index in range(0, Comptime::get(block_size_k), unroll) {
|
||||
register_m = shared_lhs[unit_row / Comptime::runtime(vectorization)
|
||||
* Comptime::runtime(block_size_k)
|
||||
+ dot_index];
|
||||
for dot_index in range(0u32, Comptime::get(block_size_k), unroll) {
|
||||
let lhs_pos =
|
||||
unit_row / Comptime::runtime(tile_size) * Comptime::runtime(block_size_k) + dot_index;
|
||||
let rhs_pos =
|
||||
(dot_index * Comptime::runtime(block_size_n) + unit_col) / Comptime::runtime(tile_size);
|
||||
|
||||
// Get a tile
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let WHAT = UInt::new(0); // TODO of course
|
||||
register_m[i] = shared_lhs[lhs_pos + i * WHAT];
|
||||
register_n[i] = shared_rhs[rhs_pos + i * WHAT];
|
||||
}
|
||||
|
||||
// Replaceable with tensor core call
|
||||
for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let row = register_m[res_idx_m];
|
||||
let pos_m = res_idx_m * Comptime::runtime(tile_size);
|
||||
for res_idx_n in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let col = register_n[res_idx_n];
|
||||
results[pos_m + res_idx_n] += row * col;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Copies values from `kernel_state.results` into `kernel_state.out`.
///
/// The loop runs `res_idx_m` from 0 up to `tile_size`, where `tile_size` is
/// taken from `Comptime::vectorization(kernel_state.lhs)`. Each iteration
/// performs a single store:
/// `out[row + res_idx_m * out_stride_row + col] = results[res_idx_m * tile_size]`.
///
/// When the comptime flag `check_m_bounds` is set, the store is only done if
/// `row_index < kernel_state.dim_m`; otherwise it is done unconditionally.
///
/// NOTE(review): `out_stride_col` and `check_n_bounds` are bound below but are
/// not used by any live statement (`check_n_bounds` only appears in the
/// commented-out code). The original inline comments also mark the column
/// index and the `results` position as uncertain — confirm the indexing
/// before relying on this function.
#[cube]
|
||||
fn write_to_output<F: Float>(
|
||||
mut kernel_state: Tiling2dState<F>,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
) {
|
||||
// No bounds check version
// NOTE(review): despite the comment above, the loop below does branch on
// `check_m_bounds`.
|
||||
let row = kernel_state.row;
|
||||
let col = kernel_state.col;
|
||||
let out_stride_row = kernel_state.out_stride_row;
|
||||
// Read here but not referenced again in this function.
let out_stride_col = kernel_state.out_stride_col;
|
||||
let results = kernel_state.results;
|
||||
let unroll = Comptime::map(config, |c| c.unroll);
|
||||
let check_m_bounds = Comptime::map(config, |c| c.check_m_bounds);
|
||||
// Only referenced by the commented-out block inside the loop.
let check_n_bounds = Comptime::map(config, |c| c.check_n_bounds);
|
||||
// The tile size is derived from the vectorization factor of `lhs`.
let tile_size = Comptime::vectorization(kernel_state.lhs);
|
||||

||||
for res_idx_m in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
// Only the per-iteration offset is scaled by `out_stride_row`; `row` is added as-is.
let row_index = row + res_idx_m * out_stride_row;
|
||||
let col_index = col; // Not sure
|
||||
let results_pos_m = res_idx_m * Comptime::runtime(tile_size); // Times tile_size, or no because vectorized?
|
||||

||||
// `Comptime::get` makes this branch depend on the comptime config value.
if Comptime::get(check_m_bounds) {
|
||||
// // TODO: Not sure if necessary. SM already padded if overflowing
|
||||
// if Comptime::get(check_n_bounds) {
|
||||
// within_output = within_output && col_index < kernel_state.dim_n;
|
||||
// }
|
||||
if row_index < kernel_state.dim_m {
|
||||
// Warning: can't do the following:
|
||||
// let mut out = kernel_state.out;
|
||||
kernel_state.out[row_index + col_index] = results[results_pos_m];
|
||||
}
|
||||
} else {
|
||||
// Same store as the checked branch, without the `dim_m` guard.
kernel_state.out[row_index + col_index] = results[results_pos_m];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
/// Kernel for tiling2d matmul
|
||||
pub fn tiling2d_matmul_kernel<F: Float>(
|
||||
lhs: Tensor<F>,
|
||||
rhs: Tensor<F>,
|
||||
|
@ -330,21 +399,21 @@ pub fn tiling2d_matmul_kernel<F: Float>(
|
|||
config: Comptime<CubeTiling2dConfig>,
|
||||
) {
|
||||
let block_size_k = Comptime::map(config, |c| c.block_size_k);
|
||||
let kernel_state = gather_kernel_information::<F>(lhs, rhs, out, config);
|
||||
let mut kernel_state = gather_kernel_information::<F>(lhs, rhs, out, config);
|
||||
|
||||
for i in range(0u32, kernel_state.n_loops, Comptime::new(false)) {
|
||||
let k = i * Comptime::runtime(block_size_k);
|
||||
kernel_state.k = i * Comptime::runtime(block_size_k);
|
||||
|
||||
load_shared_memory(kernel_state, config);
|
||||
|
||||
sync_units();
|
||||
|
||||
computation_loop(kernel_state);
|
||||
computation_loop(kernel_state, config);
|
||||
|
||||
sync_units();
|
||||
}
|
||||
|
||||
write_to_output(kernel_state);
|
||||
write_to_output(kernel_state, config);
|
||||
}
|
||||
|
||||
/// Matrix multiplication using tiling 2d algorithm with
|
||||
|
@ -355,9 +424,9 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
|
|||
out: JitTensor<R, E, D>,
|
||||
config: Tiling2dConfig,
|
||||
) -> JitTensor<R, E, D> {
|
||||
// Bound checks can be done comptime specifically for all dims
|
||||
// let bounds_check_required = check_bound_requirement(&lhs.shape, &rhs.shape, &config);
|
||||
let m = lhs.shape.dims[D - 2];
|
||||
let k = lhs.shape.dims[D - 1];
|
||||
let n = rhs.shape.dims[D - 1];
|
||||
|
||||
let client = lhs.client.clone();
|
||||
|
||||
|
@ -373,9 +442,9 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
|
|||
let cube_count = tiling2d_launch_options(&out.shape, config.clone());
|
||||
|
||||
let settings = KernelSettings::default()
|
||||
.vectorize_input(0, 1)
|
||||
.vectorize_input(1, 1)
|
||||
.vectorize_output(0, 1);
|
||||
.vectorize_input(0, 4)
|
||||
.vectorize_input(1, 4)
|
||||
.vectorize_output(0, 4);
|
||||
|
||||
tiling2d_matmul_kernel_launch::<E::CubeElement, R>(
|
||||
client,
|
||||
|
@ -384,7 +453,7 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
|
|||
TensorHandle::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
|
||||
TensorHandle::new(&out.handle, &out.strides, &out.shape.dims),
|
||||
CubeTiling2dConfig::new(config, k),
|
||||
CubeTiling2dConfig::new(config, m, k, n),
|
||||
);
|
||||
|
||||
out
|
||||
|
|
|
@ -88,6 +88,7 @@ where
|
|||
}
|
||||
|
||||
let compile = kernel.compile();
|
||||
println!("{}", compile.source);
|
||||
let pipeline = self.compile_source(&compile.source);
|
||||
|
||||
self.pipelines.insert(kernel_id.clone(), pipeline.clone());
|
||||
|
|
Loading…
Reference in New Issue