diff --git a/backend-comparison/benches/matmul.rs b/backend-comparison/benches/matmul.rs index 15941a61b..b3e9dfdcf 100644 --- a/backend-comparison/benches/matmul.rs +++ b/backend-comparison/benches/matmul.rs @@ -52,10 +52,10 @@ fn bench( token: Option<&str>, ) { const D: usize = 3; - let batch_size = 4048; - let m = 320; - let k = 4; - let n = 324; + let batch_size = 1000; + let m = 256; + let k = 512; + let n = 256; let shape_lhs = [batch_size, m, k].into(); let shape_rhs = [batch_size, k, n].into(); diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index cdcfcc561..fdb2f7903 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -108,7 +108,7 @@ pub enum MatmulStrategy { #[cfg(feature = "autotune")] impl Default for MatmulStrategy { fn default() -> Self { - MatmulStrategy::Tiling2dCube(Tiling2dConfig::default()) + MatmulStrategy::Autotune } } diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/mod.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/mod.rs index d81cd54c7..570e2bfe7 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/mod.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/mod.rs @@ -9,7 +9,10 @@ mod tiling2d_core; mod write_output; pub use base::matmul_tiling_2d_cube; -pub use compute_loop::tests as compute_loop_tests; -pub use load_shared_memory::tests as load_shared_memory_tests; -pub use outer_product::tests as outer_product_tests; -pub use write_output::tests as write_output_tests; + +#[cfg(feature = "export_tests")] +pub use { + compute_loop::tests as compute_loop_tests, + load_shared_memory::tests as load_shared_memory_tests, + outer_product::tests as outer_product_tests, write_output::tests as write_output_tests, +};