diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index 62c35b64d..4b12f40fc 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -53,9 +53,9 @@ fn bench( ) { const D: usize = 3; let batch_size = 32; - let m = 256; - let k = 1024; - let n = 256; + let m = 1024; + let k = 256; + let n = 1024; let shape_lhs = [batch_size, m, k].into(); let shape_rhs = [batch_size, k, n].into(); diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 9643c0a0b..7a2c4292b 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -31,10 +31,7 @@ pub enum MatmulStrategy { #[cfg(feature = "autotune")] impl Default for MatmulStrategy { fn default() -> Self { - MatmulStrategy::Simple { - grid_x: 16, - grid_y: 16, - } + MatmulStrategy::Autotune } } diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs index 8879bdcc6..1babc6300 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/base.rs @@ -193,6 +193,8 @@ pub fn matmul_tiling_2d_cube( } => (tensor, transposed), MemoryLayout::HighlyPermuted => (into_contiguous(tensor), false), }; + + // let check_layout = |tensor: JitTensor| (into_contiguous(tensor), false); let (lhs, lhs_transposed) = check_layout(lhs); let (rhs, rhs_transposed) = check_layout(rhs);