mirror of https://github.com/tracel-ai/burn.git
in search of the bug
This commit is contained in:
parent
081fd782af
commit
0dbc3c5af4
|
@ -17,18 +17,27 @@ pub(crate) fn compute_loop<F: Float>(
|
|||
) {
|
||||
let tile_size = Comptime::map(config, |c| c.tile_size);
|
||||
let block_size_m = Comptime::map(config, |c| c.block_size_m);
|
||||
let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
|
||||
let block_size_k = Comptime::map(config, |c| c.block_size_k);
|
||||
let block_size_n = Comptime::map(config, |c| c.block_size_n);
|
||||
let unroll = Comptime::map(config, |c| c.unroll);
|
||||
// let unroll = Comptime::map(config, |c| c.unroll);
|
||||
|
||||
let unit_row = coordinates.unit_row;
|
||||
let unit_col = coordinates.unit_col;
|
||||
|
||||
for dot_index in range(0u32, block_size_k, unroll) {
|
||||
let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m))
|
||||
/ Comptime::runtime(tile_size)];
|
||||
let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n))
|
||||
/ Comptime::runtime(tile_size)];
|
||||
for dot_index in range(0u32, Comptime::runtime(block_size_k), Comptime::new(false)) {
|
||||
let lhs_index =
|
||||
(unit_row + dot_index * Comptime::runtime(block_size_m)) / Comptime::runtime(tile_size);
|
||||
let mut register_m = F::vectorized(0., Comptime::get(tile_size));
|
||||
if lhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) {
|
||||
register_m = shared_lhs[lhs_index];
|
||||
}
|
||||
|
||||
let rhs_index =
|
||||
(unit_col + dot_index * Comptime::runtime(block_size_n)) / Comptime::runtime(tile_size);
|
||||
let mut register_n = F::vectorized(0., Comptime::get(tile_size));
|
||||
if rhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) {
|
||||
register_n = shared_rhs[rhs_index];
|
||||
}
|
||||
|
||||
tile_outer_product::<F>(register_m, register_n, results, config);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
use crate::tensor;
|
||||
|
||||
use super::{base::Coordinates, config::CubeTiling2dConfig};
|
||||
|
||||
#[cube]
|
||||
|
@ -172,7 +174,12 @@ fn write_tile_plain<F: Float>(
|
|||
let sm_vectorization = Comptime::runtime(tile_size);
|
||||
|
||||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] = tile[i];
|
||||
if i < Comptime::runtime(tile_size) {
|
||||
shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] = tile[i];
|
||||
} else {
|
||||
shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] =
|
||||
F::vectorized(0., Comptime::get(tile_size));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -194,7 +201,15 @@ fn write_tile_transposed<F: Float>(
|
|||
for i in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
let mut transposed = F::vectorized(0., Comptime::get(tile_size));
|
||||
for j in range(0u32, Comptime::get(tile_size), unroll) {
|
||||
transposed[j] = tile[j][i];
|
||||
let mut row = tile[j];
|
||||
if j > Comptime::runtime(tile_size) {
|
||||
row = F::vectorized(0., Comptime::get(tile_size));
|
||||
}
|
||||
let mut elem = row[i];
|
||||
if i > Comptime::runtime(tile_size) {
|
||||
elem = F::new(0.);
|
||||
}
|
||||
transposed[j] = elem;
|
||||
}
|
||||
|
||||
let sm_position = (sm_position_base + i * sm_stride) / sm_vectorization;
|
||||
|
@ -222,6 +237,11 @@ fn read_with_both_checks<F: Float>(
|
|||
num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
|
||||
}
|
||||
|
||||
let zeros = F::vectorized(0., Comptime::get(tile_size));
|
||||
for i in range(0u32, Comptime::get(tile_size), Comptime::new(false)) {
|
||||
tile[i] = zeros;
|
||||
}
|
||||
|
||||
for i in range(0u32, num_reads, Comptime::new(false)) {
|
||||
read_tile_line_with_checks::<F>(
|
||||
tensor,
|
||||
|
@ -259,6 +279,11 @@ fn read_with_vertical_checks<F: Float>(
|
|||
num_reads = UInt::min(dim_vertical - row, tile_size_runtime);
|
||||
}
|
||||
|
||||
let zeros = F::vectorized(0., Comptime::get(tile_size));
|
||||
for i in range(0u32, Comptime::get(tile_size), Comptime::new(false)) {
|
||||
tile[i] = zeros;
|
||||
}
|
||||
|
||||
for i in range(0u32, num_reads, Comptime::new(false)) {
|
||||
read_tile_line_without_checks::<F>(
|
||||
tensor,
|
||||
|
@ -345,7 +370,11 @@ fn read_tile_line_with_checks<F: Float>(
|
|||
if col >= dim_horizontal {
|
||||
tile[i] = F::vectorized(0., Comptime::get(tile_size));
|
||||
} else {
|
||||
tile[i] = tensor[position / runtime_vectorization];
|
||||
if position / runtime_vectorization >= tensor.len() {
|
||||
tile[i] = F::vectorized(0., Comptime::get(tile_size));
|
||||
} else {
|
||||
tile[i] = tensor[position / runtime_vectorization];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let tile_entry = F::vectorized(0., Comptime::get(tile_size));
|
||||
|
@ -387,7 +416,11 @@ fn read_tile_line_without_checks<F: Float>(
|
|||
let position = position_base + i * stride;
|
||||
|
||||
if tile_size == vectorization_factor {
|
||||
tile[i] = tensor[position / runtime_vectorization];
|
||||
if position / runtime_vectorization >= tensor.len() {
|
||||
tile[i] = F::vectorized(0., Comptime::get(vectorization_factor));
|
||||
} else {
|
||||
tile[i] = tensor[position / runtime_vectorization];
|
||||
}
|
||||
} else {
|
||||
let tile_entry = F::vectorized(0., Comptime::get(tile_size));
|
||||
|
||||
|
@ -424,9 +457,16 @@ fn read_within_vector<F: Float>(
|
|||
let runtime_vectorization = Comptime::runtime(vectorization_factor);
|
||||
|
||||
if Comptime::get(is_scalar) {
|
||||
tile_entry[i] = tensor[position + i];
|
||||
if position + i >= tensor.len() {
|
||||
tile_entry[i] = F::new(0.);
|
||||
} else {
|
||||
tile_entry[i] = tensor[position + i];
|
||||
}
|
||||
} else {
|
||||
let intermediate = tensor[position / runtime_vectorization + i];
|
||||
let mut intermediate = F::vectorized(0., Comptime::get(vectorization_factor));
|
||||
if position / runtime_vectorization + i < tensor.len() {
|
||||
intermediate = tensor[position / runtime_vectorization + i];
|
||||
}
|
||||
|
||||
for j in range(0u32, Comptime::get(vectorization_factor), unroll) {
|
||||
tile_entry[i * runtime_vectorization + j] = intermediate[j];
|
||||
|
|
|
@ -148,7 +148,12 @@ fn write_within_vector<F: Float>(
|
|||
output_elem[j] = results[results_pos_m + index];
|
||||
}
|
||||
|
||||
out[i + out_position / Comptime::runtime(vectorization_factor)] = output_elem;
|
||||
if i + out_position / Comptime::runtime(vectorization_factor) < out.len() {
|
||||
out[i + out_position / Comptime::runtime(vectorization_factor)] = output_elem;
|
||||
} else {
|
||||
out[i + out_position / Comptime::runtime(vectorization_factor)] =
|
||||
F::vectorized(9999., Comptime::get(vectorization_factor));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -96,21 +96,25 @@ where
|
|||
}
|
||||
|
||||
fn compile_source(&self, source: &str) -> Arc<ComputePipeline> {
|
||||
let module = self.device.create_shader_module(ShaderModuleDescriptor {
|
||||
label: None,
|
||||
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
|
||||
});
|
||||
|
||||
Arc::new(
|
||||
self.device
|
||||
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
unsafe {
|
||||
let module = self
|
||||
.device
|
||||
.create_shader_module_unchecked(ShaderModuleDescriptor {
|
||||
label: None,
|
||||
layout: None,
|
||||
module: &module,
|
||||
entry_point: "main",
|
||||
compilation_options: Default::default(),
|
||||
}),
|
||||
)
|
||||
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
|
||||
});
|
||||
|
||||
Arc::new(
|
||||
self.device
|
||||
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: None,
|
||||
module: &module,
|
||||
entry_point: "main",
|
||||
compilation_options: Default::default(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn buffer_reader(&mut self, handle: server::Binding<Self>) -> BufferReader {
|
||||
|
|
Loading…
Reference in New Issue