From b2bd0fb31383e692193c65461a0e00aabc1099d4 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 28 Jun 2024 14:19:21 -0400 Subject: [PATCH] wip --- .../matmul/tiling2d_cube/compute_loop.rs | 45 +++++++++---------- crates/burn-wgpu/src/compute/server.rs | 32 ++++++------- 2 files changed, 36 insertions(+), 41 deletions(-) diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs index 5348ab57f..cbd2fb8cc 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs @@ -17,27 +17,18 @@ pub(crate) fn compute_loop( ) { let tile_size = Comptime::map(config, |c| c.tile_size); let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_k = Comptime::map(config, |c| c.block_size_k); + let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); let block_size_n = Comptime::map(config, |c| c.block_size_n); - // let unroll = Comptime::map(config, |c| c.unroll); + let unroll = Comptime::map(config, |c| c.unroll); let unit_row = coordinates.unit_row; let unit_col = coordinates.unit_col; - for dot_index in range(0u32, Comptime::runtime(block_size_k), Comptime::new(false)) { - let lhs_index = - (unit_row + dot_index * Comptime::runtime(block_size_m)) / Comptime::runtime(tile_size); - let mut register_m = F::vectorized(0., Comptime::get(tile_size)); - if lhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) { - register_m = shared_lhs[lhs_index]; - } - - let rhs_index = - (unit_col + dot_index * Comptime::runtime(block_size_n)) / Comptime::runtime(tile_size); - let mut register_n = F::vectorized(0., Comptime::get(tile_size)); - if rhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) { - register_n = shared_rhs[rhs_index]; - } + for dot_index in range(0u32, block_size_k, unroll) { + let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m)) + / Comptime::runtime(tile_size)]; + let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n)) + / Comptime::runtime(tile_size)]; tile_outer_product::(register_m, register_n, results, config); } @@ -138,8 +129,8 @@ pub mod tests { /// Exported test pub fn compute_loop_unit_offset_test(device: &R::Device) { - let lhs = range_tensor_transposed::(8, 4, device); - let rhs = range_tensor::(4, 8, device); + let shared_lhs = range_tensor::(4, 8, device); + let shared_rhs = range_tensor::(4, 8, device); let results = create_empty::(TILE_SIZE, TILE_SIZE, device); let cube_dim = CubeDim::new(1, 1, 1); let cube_count = CubeCount::new(1, 1, 1); @@ -152,11 +143,19 @@ pub mod tests { let config = make_config(4, 8, 4); compute_loop_test_launch::( - lhs.client.clone(), + shared_lhs.client.clone(), cube_count, settings, - TensorHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims), - TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims), + TensorHandle::new( + &shared_lhs.handle, + &shared_lhs.strides, + &shared_lhs.shape.dims, + ), + TensorHandle::new( + &shared_rhs.handle, + &shared_rhs.strides, + &shared_rhs.shape.dims, + ), 4, 4, ArrayHandle::new(&results, 1), @@ -164,8 +163,8 @@ pub mod tests { ); let expected = &[ - 1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0, - 1978.0, 1928.0, 2046.0, 2164.0, 2282.0, + 1344.0, 1408.0, 1472.0, 1536.0, 1408.0, 1476.0, 1544.0, 1612.0, 1472.0, 1544.0, 1616.0, + 1688.0, 1536.0, 1612.0, 1688.0, 1764.0, ]; assert_equals::(results, expected, device); } diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index 3404cd6b1..a03eb74e3 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -96,25 +96,21 @@ where } fn compile_source(&self, source: &str) -> Arc { - unsafe { - let module = self - .device - .create_shader_module_unchecked(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }); + let module = self.device.create_shader_module(ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }); - Arc::new( - self.device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: None, - layout: None, - module: &module, - entry_point: "main", - compilation_options: Default::default(), - }), - ) - } + Arc::new( + self.device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: "main", + compilation_options: Default::default(), + }), + ) } fn buffer_reader(&mut self, handle: server::Binding) -> BufferReader {