in search of the bug

This commit is contained in:
louisfd 2024-06-28 09:27:05 -04:00
parent 081fd782af
commit 0dbc3c5af4
4 changed files with 86 additions and 28 deletions

View File

@ -17,18 +17,27 @@ pub(crate) fn compute_loop<F: Float>(
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let block_size_m = Comptime::map(config, |c| c.block_size_m);
let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
let block_size_k = Comptime::map(config, |c| c.block_size_k);
let block_size_n = Comptime::map(config, |c| c.block_size_n);
let unroll = Comptime::map(config, |c| c.unroll);
// let unroll = Comptime::map(config, |c| c.unroll);
let unit_row = coordinates.unit_row;
let unit_col = coordinates.unit_col;
for dot_index in range(0u32, block_size_k, unroll) {
let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m))
/ Comptime::runtime(tile_size)];
let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n))
/ Comptime::runtime(tile_size)];
for dot_index in range(0u32, Comptime::runtime(block_size_k), Comptime::new(false)) {
let lhs_index =
(unit_row + dot_index * Comptime::runtime(block_size_m)) / Comptime::runtime(tile_size);
let mut register_m = F::vectorized(0., Comptime::get(tile_size));
if lhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) {
register_m = shared_lhs[lhs_index];
}
let rhs_index =
(unit_col + dot_index * Comptime::runtime(block_size_n)) / Comptime::runtime(tile_size);
let mut register_n = F::vectorized(0., Comptime::get(tile_size));
if rhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) {
register_n = shared_rhs[rhs_index];
}
tile_outer_product::<F>(register_m, register_n, results, config);
}

View File

@ -1,5 +1,7 @@
use burn_cube::prelude::*;
use crate::tensor;
use super::{base::Coordinates, config::CubeTiling2dConfig};
#[cube]
@ -172,7 +174,12 @@ fn write_tile_plain<F: Float>(
let sm_vectorization = Comptime::runtime(tile_size);
for i in range(0u32, Comptime::get(tile_size), unroll) {
shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] = tile[i];
if i < Comptime::runtime(tile_size) {
shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] = tile[i];
} else {
shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] =
F::vectorized(0., Comptime::get(tile_size));
}
}
}
@ -194,7 +201,15 @@ fn write_tile_transposed<F: Float>(
for i in range(0u32, Comptime::get(tile_size), unroll) {
let mut transposed = F::vectorized(0., Comptime::get(tile_size));
for j in range(0u32, Comptime::get(tile_size), unroll) {
transposed[j] = tile[j][i];
let mut row = tile[j];
if j > Comptime::runtime(tile_size) {
row = F::vectorized(0., Comptime::get(tile_size));
}
let mut elem = row[i];
if i > Comptime::runtime(tile_size) {
elem = F::new(0.);
}
transposed[j] = elem;
}
let sm_position = (sm_position_base + i * sm_stride) / sm_vectorization;
@ -222,6 +237,11 @@ fn read_with_both_checks<F: Float>(
num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
}
let zeros = F::vectorized(0., Comptime::get(tile_size));
for i in range(0u32, Comptime::get(tile_size), Comptime::new(false)) {
tile[i] = zeros;
}
for i in range(0u32, num_reads, Comptime::new(false)) {
read_tile_line_with_checks::<F>(
tensor,
@ -259,6 +279,11 @@ fn read_with_vertical_checks<F: Float>(
num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
}
let zeros = F::vectorized(0., Comptime::get(tile_size));
for i in range(0u32, Comptime::get(tile_size), Comptime::new(false)) {
tile[i] = zeros;
}
for i in range(0u32, num_reads, Comptime::new(false)) {
read_tile_line_without_checks::<F>(
tensor,
@ -345,7 +370,11 @@ fn read_tile_line_with_checks<F: Float>(
if col >= dim_horizontal {
tile[i] = F::vectorized(0., Comptime::get(tile_size));
} else {
tile[i] = tensor[position / runtime_vectorization];
if position / runtime_vectorization >= tensor.len() {
tile[i] = F::vectorized(0., Comptime::get(tile_size));
} else {
tile[i] = tensor[position / runtime_vectorization];
}
}
} else {
let tile_entry = F::vectorized(0., Comptime::get(tile_size));
@ -387,7 +416,11 @@ fn read_tile_line_without_checks<F: Float>(
let position = position_base + i * stride;
if tile_size == vectorization_factor {
tile[i] = tensor[position / runtime_vectorization];
if position / runtime_vectorization >= tensor.len() {
tile[i] = F::vectorized(0., Comptime::get(vectorization_factor));
} else {
tile[i] = tensor[position / runtime_vectorization];
}
} else {
let tile_entry = F::vectorized(0., Comptime::get(tile_size));
@ -424,9 +457,16 @@ fn read_within_vector<F: Float>(
let runtime_vectorization = Comptime::runtime(vectorization_factor);
if Comptime::get(is_scalar) {
tile_entry[i] = tensor[position + i];
if position + i >= tensor.len() {
tile_entry[i] = F::new(0.);
} else {
tile_entry[i] = tensor[position + i];
}
} else {
let intermediate = tensor[position / runtime_vectorization + i];
let mut intermediate = F::vectorized(0., Comptime::get(vectorization_factor));
if position / runtime_vectorization + i < tensor.len() {
intermediate = tensor[position / runtime_vectorization + i];
}
for j in range(0u32, Comptime::get(vectorization_factor), unroll) {
tile_entry[i * runtime_vectorization + j] = intermediate[j];

View File

@ -148,7 +148,12 @@ fn write_within_vector<F: Float>(
output_elem[j] = results[results_pos_m + index];
}
out[i + out_position / Comptime::runtime(vectorization_factor)] = output_elem;
if i + out_position / Comptime::runtime(vectorization_factor) < out.len() {
out[i + out_position / Comptime::runtime(vectorization_factor)] = output_elem;
} else {
out[i + out_position / Comptime::runtime(vectorization_factor)] =
F::vectorized(9999., Comptime::get(vectorization_factor));
}
}
}

View File

@ -96,21 +96,25 @@ where
}
fn compile_source(&self, source: &str) -> Arc<ComputePipeline> {
let module = self.device.create_shader_module(ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
});
Arc::new(
self.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
unsafe {
let module = self
.device
.create_shader_module_unchecked(ShaderModuleDescriptor {
label: None,
layout: None,
module: &module,
entry_point: "main",
compilation_options: Default::default(),
}),
)
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
});
Arc::new(
self.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: None,
module: &module,
entry_point: "main",
compilation_options: Default::default(),
}),
)
}
}
fn buffer_reader(&mut self, handle: server::Binding<Self>) -> BufferReader {