From 3f0f1f23fd7bdcb0ef9dfba70a8ec6aeadd4f4a4 Mon Sep 17 00:00:00 2001
From: louisfd
Date: Mon, 15 Jul 2024 15:55:45 -0400
Subject: [PATCH] debugging

---
 .../src/kernel/matmul/cmma/compute_loop.rs    | 28 +++++++++++--------
 .../src/kernel/matmul/cmma/write_output.rs    | 14 +++++++---
 2 files changed, 27 insertions(+), 15 deletions(-)

diff --git a/crates/burn-jit/src/kernel/matmul/cmma/compute_loop.rs b/crates/burn-jit/src/kernel/matmul/cmma/compute_loop.rs
index 787618385..fef83f819 100644
--- a/crates/burn-jit/src/kernel/matmul/cmma/compute_loop.rs
+++ b/crates/burn-jit/src/kernel/matmul/cmma/compute_loop.rs
@@ -54,6 +54,7 @@ pub(crate) fn compute_loop(
             .rhs
             .slice(shared_rhs_pos, shared_rhs_pos + num_tile_elems);
 
+        // cmma_computation(lhs_slice, rhs_slice, accumulate_slice);
         cmma_row_major_mimic(lhs_slice, rhs_slice, accumulate_slice);
     }
 }
@@ -422,13 +423,15 @@ pub mod tests {
             3034496.0, 3042552.0, 3050608.0, 3058664.0, 3066720.0, 3074776.0, 3082832.0, 3090888.0,
             3098944.0, 3107000.0, 3115056.0, 3123112.0, 3131168.0, 3139224.0, 3147280.0, 3155336.0,
         ];
-        assert_equals::<R>(results, expected, device);
+        assert_equals_range::<R>(results, expected, 768..1024, device);
     }
 
     /// Exported test
     pub fn compute_loop_k_test<R: JitRuntime>(device: &R::Device) {
-        let lhs = range_tensor::<R>(16, 32, device);
-        let rhs = range_tensor::<R>(32, 16, device);
+        type FC1 = f32;
+        type FC2 = F32;
+        let lhs = range_tensor_generic::<FC1, R>(16, 32, device);
+        let rhs = range_tensor_generic::<FC1, R>(32, 16, device);
         let results = create_empty::<R>(16, 16, device);
         let cube_dim = CubeDim::new(1, 32, 1);
         let cube_count = CubeCount::Static(1, 1, 1);
@@ -447,16 +450,16 @@ pub mod tests {
             unroll: false,
         };
 
-        compute_loop_test::launch::<F32, R>(
+        compute_loop_test::launch::<FC2, R>(
             lhs.client.clone(),
             cube_count,
             cube_dim,
             TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
             TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
             ArrayArg::new(&results, 256),
-            UInt::new(16),
+            UInt::new(64),
             UInt::new(32),
-            UInt::new(16),
+            UInt::new(64),
             config,
         );
 
@@ -499,8 +502,11 @@ pub mod tests {
 
     /// Exported test
     pub fn compute_loop_warp_test<R: JitRuntime>(device: &R::Device) {
-        let lhs = range_tensor::<R>(16, 32, device);
-        let rhs = range_tensor::<R>(32, 32, device);
+        type FC1 = f32;
+        type FC2 = F32;
+
+        let lhs = range_tensor_generic::<FC1, R>(16, 32, device);
+        let rhs = range_tensor_generic::<FC1, R>(32, 32, device);
         let results = create_empty::<R>(16, 32, device);
         let cube_dim = CubeDim::new(1, 32, 1);
         let cube_count = CubeCount::Static(1, 1, 1);
@@ -519,16 +525,16 @@ pub mod tests {
             unroll: false,
         };
 
-        compute_loop_test::launch::<F32, R>(
+        compute_loop_test::launch::<FC2, R>(
             lhs.client.clone(),
             cube_count,
             cube_dim,
             TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
             TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
             ArrayArg::new(&results, 512),
-            UInt::new(16),
-            UInt::new(32),
+            UInt::new(64),
             UInt::new(32),
+            UInt::new(64),
             config,
         );
 
diff --git a/crates/burn-jit/src/kernel/matmul/cmma/write_output.rs b/crates/burn-jit/src/kernel/matmul/cmma/write_output.rs
index d7215a295..c0a62f343 100644
--- a/crates/burn-jit/src/kernel/matmul/cmma/write_output.rs
+++ b/crates/burn-jit/src/kernel/matmul/cmma/write_output.rs
@@ -105,6 +105,11 @@ pub mod tests {
             k: UInt::new(0),
         };
+        let out_vec = Comptime::vectorization(out);
+        for i in range(0u32, (k * n)/Comptime::runtime(out_vec), Comptime::new(false)) {
+            out[i] = F::vectorized(0., Comptime::get(out_vec));
+        }
+
         let mut accumulate = SharedMemory::<F>::new(4096);
         for i in range(0u32, 4096u32, Comptime::new(false)) {
             accumulate[i] = acc_sm_arr[i];
         }
@@ -123,7 +128,8 @@ pub mod tests {
     pub fn cmma_write_output_unit_test<R: JitRuntime>(device: &R::Device) {
         let k = 16;
         let n = 32;
-        let out = zeros_tensor::<R>(k, n, device);
+        // TODO should be zeros_tensor, rather than range then put back to 0, but fails on cuda
+        let out = range_tensor::<R>(k, n, device);
         let acc_sm = range_tensor::<R>(64, 64, device);
         let cube_dim = CubeDim::new(1, 1, 1);
         let cube_count: CubeCount = CubeCount::Static(1, 1, 1);
@@ -193,7 +199,7 @@ pub mod tests {
     pub fn cmma_write_output_warp_test<R: JitRuntime>(device: &R::Device) {
         let k = 16;
         let n = 32;
-        let out = zeros_tensor::<R>(k, n, device);
+        let out = range_tensor::<R>(k, n, device);
         let acc_sm = range_tensor::<R>(64, 64, device);
         let cube_dim = CubeDim::new(1, 32, 1);
         let cube_count: CubeCount = CubeCount::Static(1, 1, 1);
@@ -274,7 +280,7 @@ pub mod tests {
     pub fn cmma_write_output_second_warp_test<R: JitRuntime>(device: &R::Device) {
         let k = 16;
         let n = 64;
-        let out = zeros_tensor::<R>(k, n, device);
+        let out = range_tensor::<R>(k, n, device);
         let acc_sm = range_tensor::<R>(64, 64, device);
         let cube_dim = CubeDim::new(2, 32, 1);
         let cube_count: CubeCount = CubeCount::Static(1, 1, 1);
@@ -398,7 +404,7 @@ pub mod tests {
     pub fn cmma_write_output_third_fourth_warps_test<R: JitRuntime>(device: &R::Device) {
         let k = 32;
         let n = 64;
-        let out = zeros_tensor::<R>(k, n, device);
+        let out = range_tensor::<R>(k, n, device);
         let acc_sm = range_tensor::<R>(64, 64, device);
         let cube_dim = CubeDim::new(4, 32, 1);
         let cube_count: CubeCount = CubeCount::Static(1, 1, 1);
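
Note: the new hunk in write_output.rs zeroes the vectorized out buffer at the top of the
test kernel, which is what lets the tests build out with range_tensor instead of
zeros_tensor (see the TODO about the CUDA failure). As a reading aid, here is a minimal
plain-Rust sketch of that zero-fill outside the CubeCL DSL; the fixed VEC width and the
zero_out/main names are illustrative stand-ins, not part of the patch:

    // Stand-in for Comptime::vectorization(out): each slot of `out` packs
    // VEC floats, so k * n scalars occupy (k * n) / VEC vectorized slots.
    const VEC: usize = 4;

    // Mirrors `out[i] = F::vectorized(0., Comptime::get(out_vec))`:
    // one zeroed vector store per slot.
    fn zero_out(out: &mut [[f32; VEC]], k: usize, n: usize) {
        for i in 0..(k * n) / VEC {
            out[i] = [0.0; VEC];
        }
    }

    fn main() {
        let (k, n) = (16, 32);
        // Pretend the buffer still holds range data, as range_tensor produces.
        let mut out = vec![[1.0f32; VEC]; k * n / VEC];
        zero_out(&mut out, k, n);
        assert!(out.iter().all(|v| *v == [0.0f32; VEC]));
    }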