From 0dbc3c5af4695e062566d4a31661c01039256917 Mon Sep 17 00:00:00 2001 From: louisfd Date: Fri, 28 Jun 2024 09:27:05 -0400 Subject: [PATCH] in search of the bug --- .../matmul/tiling2d_cube/compute_loop.rs | 23 +++++--- .../tiling2d_cube/load_shared_memory.rs | 52 ++++++++++++++++--- .../matmul/tiling2d_cube/write_output.rs | 7 ++- crates/burn-wgpu/src/compute/server.rs | 32 +++++++----- 4 files changed, 86 insertions(+), 28 deletions(-) diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs index 4e636e8ca..5348ab57f 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/compute_loop.rs @@ -17,18 +17,27 @@ pub(crate) fn compute_loop( ) { let tile_size = Comptime::map(config, |c| c.tile_size); let block_size_m = Comptime::map(config, |c| c.block_size_m); - let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k)); + let block_size_k = Comptime::map(config, |c| c.block_size_k); let block_size_n = Comptime::map(config, |c| c.block_size_n); - let unroll = Comptime::map(config, |c| c.unroll); + // let unroll = Comptime::map(config, |c| c.unroll); let unit_row = coordinates.unit_row; let unit_col = coordinates.unit_col; - for dot_index in range(0u32, block_size_k, unroll) { - let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m)) - / Comptime::runtime(tile_size)]; - let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n)) - / Comptime::runtime(tile_size)]; + for dot_index in range(0u32, Comptime::runtime(block_size_k), Comptime::new(false)) { + let lhs_index = + (unit_row + dot_index * Comptime::runtime(block_size_m)) / Comptime::runtime(tile_size); + let mut register_m = F::vectorized(0., Comptime::get(tile_size)); + if lhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) { + register_m = shared_lhs[lhs_index]; + } + + let rhs_index = + (unit_col + dot_index * Comptime::runtime(block_size_n)) / Comptime::runtime(tile_size); + let mut register_n = F::vectorized(0., Comptime::get(tile_size)); + if rhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) { + register_n = shared_rhs[rhs_index]; + } tile_outer_product::(register_m, register_n, results, config); } diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs index 7ae3e49fa..4bd015c6b 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs @@ -1,5 +1,7 @@ use burn_cube::prelude::*; +use crate::tensor; + use super::{base::Coordinates, config::CubeTiling2dConfig}; #[cube] @@ -172,7 +174,12 @@ fn write_tile_plain( let sm_vectorization = Comptime::runtime(tile_size); for i in range(0u32, Comptime::get(tile_size), unroll) { - shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] = tile[i]; + if i < Comptime::runtime(tile_size) { + shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] = tile[i]; + } else { + shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] = + F::vectorized(0., Comptime::get(tile_size)); + } } } @@ -194,7 +201,15 @@ fn write_tile_transposed( for i in range(0u32, Comptime::get(tile_size), unroll) { let mut transposed = F::vectorized(0., Comptime::get(tile_size)); for j in range(0u32, Comptime::get(tile_size), unroll) { - transposed[j] = tile[j][i]; + let mut row = tile[j]; + if j > Comptime::runtime(tile_size) { + row = F::vectorized(0., Comptime::get(tile_size)); + } + let mut elem = row[i]; + if i > Comptime::runtime(tile_size) { + elem = F::new(0.); + } + transposed[j] = elem; } let sm_position = (sm_position_base + i * sm_stride) / sm_vectorization; @@ -222,6 +237,11 @@ fn read_with_both_checks( num_reads = UInt::min(dim_vertical - row, tile_size_runtime); } + let zeros = F::vectorized(0., Comptime::get(tile_size)); + for i in range(0u32, Comptime::get(tile_size), Comptime::new(false)) { + tile[i] = zeros; + } + for i in range(0u32, num_reads, Comptime::new(false)) { read_tile_line_with_checks::( tensor, @@ -259,6 +279,11 @@ fn read_with_vertical_checks( num_reads = UInt::min(dim_vertical - row, tile_size_runtime); } + let zeros = F::vectorized(0., Comptime::get(tile_size)); + for i in range(0u32, Comptime::get(tile_size), Comptime::new(false)) { + tile[i] = zeros; + } + for i in range(0u32, num_reads, Comptime::new(false)) { read_tile_line_without_checks::( tensor, @@ -345,7 +370,11 @@ fn read_tile_line_with_checks( if col >= dim_horizontal { tile[i] = F::vectorized(0., Comptime::get(tile_size)); } else { - tile[i] = tensor[position / runtime_vectorization]; + if position / runtime_vectorization >= tensor.len() { + tile[i] = F::vectorized(0., Comptime::get(tile_size)); + } else { + tile[i] = tensor[position / runtime_vectorization]; + } } } else { let tile_entry = F::vectorized(0., Comptime::get(tile_size)); @@ -387,7 +416,11 @@ fn read_tile_line_without_checks( let position = position_base + i * stride; if tile_size == vectorization_factor { - tile[i] = tensor[position / runtime_vectorization]; + if position / runtime_vectorization >= tensor.len() { + tile[i] = F::vectorized(0., Comptime::get(vectorization_factor)); + } else { + tile[i] = tensor[position / runtime_vectorization]; + } } else { let tile_entry = F::vectorized(0., Comptime::get(tile_size)); @@ -424,9 +457,16 @@ fn read_within_vector( let runtime_vectorization = Comptime::runtime(vectorization_factor); if Comptime::get(is_scalar) { - tile_entry[i] = tensor[position + i]; + if position + i >= tensor.len() { + tile_entry[i] = F::new(0.); + } else { + tile_entry[i] = tensor[position + i]; + } } else { - let intermediate = tensor[position / runtime_vectorization + i]; + let mut intermediate = F::vectorized(0., Comptime::get(vectorization_factor)); + if position / runtime_vectorization + i < tensor.len() { + intermediate = tensor[position / runtime_vectorization + i]; + } for j in range(0u32, Comptime::get(vectorization_factor), unroll) { tile_entry[i * runtime_vectorization + j] = intermediate[j]; diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs index e0656b23e..65e5c7bb6 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/write_output.rs @@ -148,7 +148,12 @@ fn write_within_vector( output_elem[j] = results[results_pos_m + index]; } - out[i + out_position / Comptime::runtime(vectorization_factor)] = output_elem; + if i + out_position / Comptime::runtime(vectorization_factor) < out.len() { + out[i + out_position / Comptime::runtime(vectorization_factor)] = output_elem; + } else { + out[i + out_position / Comptime::runtime(vectorization_factor)] = + F::vectorized(9999., Comptime::get(vectorization_factor)); + } } } diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index a03eb74e3..3404cd6b1 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -96,21 +96,25 @@ where } fn compile_source(&self, source: &str) -> Arc { - let module = self.device.create_shader_module(ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), - }); - - Arc::new( - self.device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + unsafe { + let module = self + .device + .create_shader_module_unchecked(ShaderModuleDescriptor { label: None, - layout: None, - module: &module, - entry_point: "main", - compilation_options: Default::default(), - }), - ) + source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), + }); + + Arc::new( + self.device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: "main", + compilation_options: Default::default(), + }), + ) + } } fn buffer_reader(&mut self, handle: server::Binding) -> BufferReader {