From e09d60fff09847f362a4e803b316313d89e7f15f Mon Sep 17 00:00:00 2001 From: louisfd Date: Sat, 29 Jun 2024 13:01:39 -0400 Subject: [PATCH] fix? --- .../tiling2d_cube/load_shared_memory.rs | 21 +- crates/burn-jit/src/tests/cube.wgsl | 236 +++++++++++ crates/burn-jit/src/tests/matmul.rs | 2 +- crates/burn-jit/src/tests/other.wgsl | 380 ++++++++++++++++++ crates/burn-wgpu/src/compute/server.rs | 1 + 5 files changed, 633 insertions(+), 7 deletions(-) create mode 100644 crates/burn-jit/src/tests/cube.wgsl create mode 100644 crates/burn-jit/src/tests/other.wgsl diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs index 2fb2dacbc..ef4b3ba4d 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d_cube/load_shared_memory.rs @@ -170,9 +170,13 @@ fn write_tile_plain( tile_size: Comptime, ) { let sm_vectorization = Comptime::runtime(tile_size); + let sm_len = UInt::new(2048) / sm_vectorization; for i in range(0u32, Comptime::get(tile_size), unroll) { - shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] = tile[i]; + let sm_pos = (sm_position_base + i * sm_stride) / sm_vectorization; + if sm_pos < sm_len { + shared_memory[sm_pos] = tile[i]; + } } } @@ -185,11 +189,14 @@ fn write_tile_transposed( unroll: Comptime, tile_size: Comptime, ) { - let is_scalar = Comptime::map(tile_size, |c| c.val == 1); + let sm_is_scalar = Comptime::map(tile_size, |c| c.val == 1); let sm_vectorization = Comptime::runtime(tile_size); + let sm_len = UInt::new(2048) / sm_vectorization; - if Comptime::get(is_scalar) { - shared_memory[sm_position_base] = tile[0]; + if Comptime::get(sm_is_scalar) { + if sm_position_base < sm_len { + shared_memory[sm_position_base] = tile[0]; + } } else { for i in range(0u32, Comptime::get(tile_size), unroll) { let mut transposed = F::vectorized(0., Comptime::get(tile_size)); @@ -199,8 +206,10 @@ fn write_tile_transposed( transposed[j] = tile[j][i]; } - let sm_position = (sm_position_base + i * sm_stride) / sm_vectorization; - shared_memory[sm_position] = transposed; + let sm_pos = (sm_position_base + i * sm_stride) / sm_vectorization; + if sm_pos < sm_len { + shared_memory[sm_pos] = transposed; + } } } } diff --git a/crates/burn-jit/src/tests/cube.wgsl b/crates/burn-jit/src/tests/cube.wgsl new file mode 100644 index 000000000..c62f0cb15 --- /dev/null +++ b/crates/burn-jit/src/tests/cube.wgsl @@ -0,0 +1,236 @@ +@group(0) +@binding(0) +var input_0_global: array>; + +@group(0) +@binding(1) +var input_1_global: array>; + +@group(0) +@binding(2) +var output_0_global: array>; + +@group(0) +@binding(3) +var info: array; + +var shared_memory_0: array, 512>; + +var shared_memory_1: array, 512>; + +const WORKGROUP_SIZE_X = 16u; +const WORKGROUP_SIZE_Y = 16u; +const WORKGROUP_SIZE_Z = 1u; + +@compute +@workgroup_size(16, 16, 1) +fn main( + @builtin(local_invocation_index) local_idx: u32, + @builtin(workgroup_id) workgroup_id: vec3, +) { + var a_0_0: array; + + var a_0_1: array, 4>; + + var a_0_2: array, 4>; + + let rank: u32 = info[0]; + let rank_2: u32 = rank * 2u; + var l_0_0: u32; + var l_0_1: u32; + var l_0_2: u32; + var l_0_3: u32; + var l_0_4: u32; + var l_0_5: u32; + var l_0_6: u32; + var l_0_7: u32; + var l_0_8: u32; + var l_0_9: u32; + var l_0_10: u32; + var l_0_11: u32; + var l_0_12: u32; + var l_0_13: u32; + var l_0_14: vec4; + var l_0_15: vec4; + var l_0_16: f32; + var l_0_17: f32; + l_0_0 = 64u - 1u; + l_0_0 = l_0_0 / 4u; + l_0_0 = l_0_0 + 1u; + l_0_1 = workgroup_id.x * 64u; + l_0_2 = workgroup_id.y * 64u; + l_0_3 = local_idx / l_0_0; + l_0_3 = l_0_3 * 4u; + l_0_0 = local_idx % l_0_0; + l_0_0 = l_0_0 * 4u; + l_0_4 = u32(rank); + l_0_5 = l_0_4 - 2u; + l_0_6 = info[(0u * rank_2) + rank + l_0_5 + 1u]; + l_0_5 = l_0_4 - 1u; + l_0_7 = info[(1u * rank_2) + rank + l_0_5 + 1u]; + l_0_6 = l_0_6 * l_0_7; + l_0_6 = l_0_6 * workgroup_id.z; + l_0_7 = u32(0u); + l_0_5 = u32(0u); + l_0_4 = l_0_4 - 2u; + + for (var l_1_0: u32 = 0u; l_1_0 < l_0_4; l_1_0++) { + l_0_8 = info[(2u * rank_2) + l_1_0 + 1u]; + l_0_8 = l_0_6 / l_0_8; + l_0_9 = info[(0u * rank_2) + rank + l_1_0 + 1u]; + l_0_9 = l_0_8 % l_0_9; + l_0_10 = info[(0u * rank_2) + l_1_0 + 1u]; + l_0_9 = l_0_9 * l_0_10; + l_0_7 = l_0_7 + l_0_9; + l_0_10 = info[(1u * rank_2) + rank + l_1_0 + 1u]; + l_0_8 = l_0_8 % l_0_10; + l_0_10 = info[(1u * rank_2) + l_1_0 + 1u]; + l_0_8 = l_0_8 * l_0_10; + l_0_5 = l_0_5 + l_0_8; + } + + for (var l_1_0: u32 = 0u; l_1_0 < 16u; l_1_0++) { + a_0_0[l_1_0] = f32(0f); + } + l_0_10 = rank - 1u; + l_0_9 = info[(0u * rank_2) + rank + l_0_10 + 1u]; + l_0_10 = u32(0u); + l_0_8 = l_0_9 / 32u; + l_0_10 = u32(l_0_8); + + for (var l_1_0: u32 = 0u; l_1_0 < l_0_10; l_1_0++) { + l_0_9 = l_1_0 * 32u; + l_0_8 = l_0_0 * 64u; + l_0_8 = l_0_8 + l_0_3; + l_0_4 = rank - 2u; + l_0_11 = info[(0u * rank_2) + l_0_4 + 1u]; + l_0_11 = l_0_1 * l_0_11; + l_0_11 = l_0_11 + l_0_9; + l_0_11 = l_0_11 + l_0_7; + l_0_4 = u32(rank); + l_0_4 = l_0_4 - 2u; + l_0_12 = info[(0u * rank_2) + l_0_4 + 1u]; + l_0_4 = l_0_3 * l_0_12; + l_0_4 = l_0_4 + l_0_0; + l_0_4 = l_0_4 + l_0_11; + + for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { + l_0_11 = l_2_0 * l_0_12; + l_0_11 = l_0_4 + l_0_11; + l_0_13 = l_0_11 / 4u; + l_0_14 = vec4(input_0_global[l_0_13]); + a_0_1[l_2_0] = vec4(l_0_14); + } + + for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { + l_0_14[0u] = f32(0f); + l_0_14[1u] = f32(0f); + l_0_14[2u] = f32(0f); + l_0_14[3u] = f32(0f); + l_0_15 = vec4(a_0_1[0u]); + l_0_16 = f32(l_0_15[l_2_0]); + l_0_14[0u] = f32(l_0_16); + l_0_15 = vec4(a_0_1[1u]); + l_0_16 = f32(l_0_15[l_2_0]); + l_0_14[1u] = f32(l_0_16); + l_0_15 = vec4(a_0_1[2u]); + l_0_16 = f32(l_0_15[l_2_0]); + l_0_14[2u] = f32(l_0_16); + l_0_15 = vec4(a_0_1[3u]); + l_0_16 = f32(l_0_15[l_2_0]); + l_0_14[3u] = f32(l_0_16); + l_0_13 = l_2_0 * 64u; + l_0_13 = l_0_8 + l_0_13; + l_0_13 = l_0_13 / 4u; + shared_memory_0[l_0_13] = vec4(l_0_14); + } + l_0_13 = rank - 2u; + l_0_12 = info[(1u * rank_2) + l_0_13 + 1u]; + l_0_13 = l_0_3 * 64u; + l_0_13 = l_0_13 + l_0_0; + l_0_12 = l_0_9 * l_0_12; + l_0_12 = l_0_2 + l_0_12; + l_0_12 = l_0_12 + l_0_5; + l_0_11 = u32(rank); + l_0_11 = l_0_11 - 2u; + l_0_8 = info[(1u * rank_2) + l_0_11 + 1u]; + l_0_11 = l_0_3 * l_0_8; + l_0_11 = l_0_11 + l_0_0; + l_0_11 = l_0_11 + l_0_12; + + for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { + l_0_12 = l_2_0 * l_0_8; + l_0_12 = l_0_11 + l_0_12; + l_0_4 = l_0_12 / 4u; + l_0_15 = vec4(input_1_global[l_0_4]); + a_0_2[l_2_0] = vec4(l_0_15); + } + + for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { + l_0_12 = l_2_0 * 64u; + l_0_12 = l_0_13 + l_0_12; + l_0_12 = l_0_12 / 4u; + l_0_15 = vec4(a_0_2[l_2_0]); + shared_memory_1[l_0_12] = vec4(l_0_15); + } + workgroupBarrier(); + l_0_13 = u32(l_0_3); + l_0_12 = u32(l_0_0); + + for (var l_2_0: u32 = 0u; l_2_0 < 32u; l_2_0++) { + l_0_11 = l_2_0 * 64u; + l_0_11 = l_0_13 + l_0_11; + l_0_11 = l_0_11 / 4u; + l_0_15 = vec4(shared_memory_0[l_0_11]); + l_0_11 = l_2_0 * 64u; + l_0_11 = l_0_12 + l_0_11; + l_0_11 = l_0_11 / 4u; + l_0_14 = vec4(shared_memory_1[l_0_11]); + + for (var l_3_0: u32 = 0u; l_3_0 < 4u; l_3_0++) { + l_0_11 = l_3_0 * 4u; + + for (var l_4_0: u32 = 0u; l_4_0 < 4u; l_4_0++) { + l_0_16 = f32(l_0_15[l_3_0]); + l_0_17 = f32(l_0_14[l_4_0]); + l_0_16 = l_0_16 * l_0_17; + l_0_9 = l_0_11 + l_4_0; + l_0_17 = f32(a_0_0[l_0_9]); + l_0_17 = l_0_17 + l_0_16; + a_0_0[l_0_9] = f32(l_0_17); + } + } + } + workgroupBarrier(); + } + l_0_13 = l_0_1 + l_0_3; + l_0_0 = l_0_2 + l_0_0; + l_0_12 = rank - 2u; + l_0_11 = info[(2u * rank_2) + l_0_12 + 1u]; + + for (var l_1_0: u32 = 0u; l_1_0 < 4u; l_1_0++) { + l_0_12 = l_1_0 * 4u; + l_0_10 = l_0_13 + l_1_0; + l_0_10 = l_0_10 * l_0_11; + l_0_10 = l_0_10 + l_0_0; + l_0_10 = l_0_10 + l_0_6; + + for (var l_2_0: u32 = 0u; l_2_0 < 1u; l_2_0++) { + l_0_15[0u] = f32(0f); + l_0_15[1u] = f32(0f); + l_0_15[2u] = f32(0f); + l_0_15[3u] = f32(0f); + + for (var l_3_0: u32 = 0u; l_3_0 < 4u; l_3_0++) { + l_0_9 = l_2_0 * 4u; + l_0_9 = l_0_9 + l_3_0; + l_0_9 = l_0_12 + l_0_9; + l_0_17 = f32(a_0_0[l_0_9]); + l_0_15[l_3_0] = f32(l_0_17); + } + l_0_9 = l_0_10 / 4u; + l_0_9 = l_2_0 + l_0_9; + output_0_global[l_0_9] = vec4(l_0_15); + } + } +} \ No newline at end of file diff --git a/crates/burn-jit/src/tests/matmul.rs b/crates/burn-jit/src/tests/matmul.rs index 2e146ad57..4e73ea74e 100644 --- a/crates/burn-jit/src/tests/matmul.rs +++ b/crates/burn-jit/src/tests/matmul.rs @@ -432,7 +432,7 @@ mod tests { #[test] pub fn k_exceeds_block() { - test_with_params(64, 36, 32, 1, 1); + test_with_params(64, 32, 64, 1, 1); } #[test] diff --git a/crates/burn-jit/src/tests/other.wgsl b/crates/burn-jit/src/tests/other.wgsl new file mode 100644 index 000000000..95a582552 --- /dev/null +++ b/crates/burn-jit/src/tests/other.wgsl @@ -0,0 +1,380 @@ +@group(0) +@binding(0) +var input_0_global: array; + +@group(0) +@binding(1) +var input_1_global: array; + +@group(0) +@binding(2) +var output_0_global: array; + +@group(0) +@binding(3) +var info: array; + +var shared_memory_0: array, 512>; + +var shared_memory_1: array, 512>; + +const WORKGROUP_SIZE_X = 16u; +const WORKGROUP_SIZE_Y = 16u; +const WORKGROUP_SIZE_Z = 1u; + +@compute +@workgroup_size(16, 16, 1) +fn main( + @builtin(global_invocation_id) global_id: vec3, + @builtin(local_invocation_index) local_idx: u32, + @builtin(workgroup_id) workgroup_id: vec3, +) { + + var a_0_0: array; + + let rank: u32 = info[0]; + let rank_2: u32 = rank * 2u; + var l_0_0: u32; + var l_0_1: u32; + var l_0_2: u32; + var l_0_3: u32; + var l_0_4: u32; + var l_0_5: u32; + var l_0_6: u32; + var l_0_7: u32; + var l_0_8: u32; + var l_0_9: u32; + var l_0_10: u32; + var l_0_11: u32; + var l_0_12: u32; + var l_0_13: u32; + var l_0_14: u32; + var l_0_15: u32; + var l_0_16: u32; + var l_0_17: u32; + var l_0_18: u32; + var l_0_19: u32; + var l_0_20: u32; + var l_0_21: u32; + var l_0_22: u32; + var l_0_23: u32; + var l_0_24: u32; + var l_0_25: u32; + var l_0_26: u32; + var l_0_27: u32; + var l_0_28: u32; + var l_0_29: vec4; + var l_0_30: vec4; + var l_0_31: u32; + var l_0_32: u32; + var l_0_33: f32; + var l_0_34: f32; + var l_0_35: f32; + var l_0_36: u32; + var l_0_37: u32; + var l_0_38: bool; + var l_0_39: bool; + l_0_0 = rank - 1u; + l_0_1 = rank - 2u; + l_0_2 = info[(0u * rank_2) + rank + l_0_1 + 1u]; + l_0_3 = info[(0u * rank_2) + rank + l_0_0 + 1u]; + l_0_4 = info[(1u * rank_2) + rank + l_0_0 + 1u]; + l_0_5 = info[(0u * rank_2) + l_0_1 + 1u]; + l_0_6 = info[(0u * rank_2) + l_0_0 + 1u]; + l_0_7 = info[(1u * rank_2) + l_0_1 + 1u]; + l_0_8 = info[(1u * rank_2) + l_0_0 + 1u]; + l_0_9 = info[(2u * rank_2) + l_0_1 + 1u]; + l_0_10 = info[(2u * rank_2) + l_0_0 + 1u]; + l_0_11 = u32(workgroup_id.x); + l_0_11 = l_0_11 * 64u; + l_0_12 = u32(workgroup_id.y); + l_0_12 = l_0_12 * 64u; + l_0_13 = local_idx / 16u; + l_0_13 = l_0_13 * 4u; + l_0_14 = local_idx % 16u; + l_0_14 = l_0_14 * 4u; + l_0_15 = l_0_11 + l_0_13; + l_0_16 = l_0_12 + l_0_14; + l_0_17 = l_0_11 * l_0_5; + l_0_18 = l_0_12 * l_0_8; + l_0_19 = l_0_2 * l_0_4; + l_0_19 = l_0_19 * global_id.z; + l_0_20 = rank - 2u; + + for (var l_1_0: u32 = 0u; l_1_0 < l_0_20; l_1_0++) { + l_0_21 = info[(0u * rank_2) + l_1_0 + 1u]; + l_0_22 = info[(1u * rank_2) + l_1_0 + 1u]; + l_0_23 = info[(2u * rank_2) + l_1_0 + 1u]; + l_0_24 = info[(0u * rank_2) + rank + l_1_0 + 1u]; + l_0_25 = info[(1u * rank_2) + rank + l_1_0 + 1u]; + l_0_26 = l_0_19 / l_0_23; + l_0_27 = l_0_26 % l_0_24; + l_0_27 = l_0_27 * l_0_21; + l_0_17 = l_0_17 + l_0_27; + l_0_28 = l_0_26 % l_0_25; + l_0_28 = l_0_28 * l_0_22; + l_0_18 = l_0_18 + l_0_28; + } + l_0_33 = f32(l_0_3); + l_0_34 = f32(32u); + l_0_35 = l_0_33 / l_0_34; + l_0_35 = ceil(l_0_35); + l_0_31 = u32(l_0_35); + + for (var l_1_0: u32 = 0u; l_1_0 < l_0_31; l_1_0++) { + var l_1_1: u32; + var l_1_2: u32; + var l_1_3: bool; + var l_1_4: u32; + var l_1_5: bool; + var l_1_6: u32; + var l_1_7: bool; + var l_1_8: bool; + var l_1_9: vec4; + var l_1_10: u32; + var l_1_11: u32; + var l_1_12: u32; + var l_1_13: u32; + var l_1_14: u32; + var l_1_15: bool; + var l_1_16: f32; + var l_1_17: f32; + var l_1_18: f32; + var l_1_19: f32; + var l_1_20: u32; + var l_1_21: u32; + var l_1_22: bool; + var l_1_23: u32; + var l_1_24: bool; + var l_1_25: u32; + var l_1_26: bool; + var l_1_27: bool; + var l_1_28: vec4; + var l_1_29: u32; + var l_1_30: u32; + var l_1_31: u32; + var l_1_32: u32; + var l_1_33: u32; + var l_1_34: bool; + var l_1_35: f32; + var l_1_36: f32; + var l_1_37: f32; + var l_1_38: f32; + var l_1_39: u32; + var l_1_40: u32; + var l_1_41: f32; + var l_1_42: f32; + var l_1_43: f32; + var l_1_44: u32; + var l_1_45: f32; + var l_1_46: f32; + l_0_32 = l_1_0 * 32u; + l_1_1 = l_0_2 - l_0_15; + + for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { + l_1_2 = l_0_14 + l_2_0; + l_1_3 = l_1_2 < 32u; + if l_1_3 { + l_1_4 = l_0_13 / 4u; + l_1_4 = l_1_4 * 32u; + l_1_4 = l_1_4 + l_1_2; + l_1_6 = l_1_2 + l_0_32; + l_1_5 = l_1_6 < l_0_3; + l_1_7 = l_1_1 >= 1u; + l_1_8 = l_1_5 && l_1_7; + if l_1_8 { + var l_4_0: u32; + l_1_11 = l_0_32 + l_1_2; + l_1_11 = l_1_11 * l_0_6; + l_1_10 = l_0_13 * l_0_5; + l_1_11 = l_1_11 + l_1_10; + l_1_11 = l_1_11 + l_0_17; + l_1_12 = l_1_11 + l_0_5; + l_1_13 = l_1_12 + l_0_5; + l_1_14 = l_1_13 + l_0_5; + l_1_15 = l_1_1 >= 4u; + if l_1_15 { + l_1_16 = f32(input_0_global[l_1_11]); + l_1_17 = f32(input_0_global[l_1_12]); + l_1_18 = f32(input_0_global[l_1_13]); + l_1_19 = f32(input_0_global[l_1_14]); + } else { + l_1_15 = l_1_1 == 3u; + if l_1_15 { + l_1_16 = f32(input_0_global[l_1_11]); + l_1_17 = f32(input_0_global[l_1_12]); + l_1_18 = f32(input_0_global[l_1_13]); + l_1_19 = f32(0u); + } else { + l_1_15 = l_1_1 == 2u; + if l_1_15 { + l_1_16 = f32(input_0_global[l_1_11]); + l_1_17 = f32(input_0_global[l_1_12]); + l_1_18 = f32(0u); + l_1_19 = f32(0u); + } else { + l_1_15 = l_1_1 == 1u; + if l_1_15 { + l_1_16 = f32(input_0_global[l_1_11]); + l_1_17 = f32(0u); + l_1_18 = f32(0u); + l_1_19 = f32(0u); + } + } + } + } + l_4_0 = u32(0u); + l_1_9[l_4_0] = f32(l_1_16); + l_4_0 = l_4_0 + 1u; + l_1_9[l_4_0] = f32(l_1_17); + l_4_0 = l_4_0 + 1u; + l_1_9[l_4_0] = f32(l_1_18); + l_4_0 = l_4_0 + 1u; + l_1_9[l_4_0] = f32(l_1_19); + shared_memory_0[l_1_4] = vec4(l_1_9); + } else { + var l_4_0: u32; + l_1_16 = f32(0u); + l_4_0 = u32(0u); + l_1_9[l_4_0] = f32(l_1_16); + l_4_0 = l_4_0 + 1u; + l_1_9[l_4_0] = f32(l_1_16); + l_4_0 = l_4_0 + 1u; + l_1_9[l_4_0] = f32(l_1_16); + l_4_0 = l_4_0 + 1u; + l_1_9[l_4_0] = f32(l_1_16); + shared_memory_0[l_1_4] = vec4(l_1_9); + } + } + } + l_1_20 = l_0_4 - l_0_16; + + for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { + l_1_21 = l_0_13 + l_2_0; + l_1_22 = l_1_21 < 32u; + if l_1_22 { + l_1_23 = l_1_21 * 64u; + l_1_23 = l_1_23 + l_0_14; + l_1_23 = l_1_23 / 4u; + l_1_25 = l_1_21 + l_0_32; + l_1_24 = l_1_25 < l_0_3; + l_1_26 = l_1_20 >= 1u; + l_1_27 = l_1_24 && l_1_26; + if l_1_27 { + var l_4_0: u32; + l_1_30 = l_0_32 + l_1_21; + l_1_30 = l_1_30 * l_0_7; + l_1_29 = l_0_14 * l_0_8; + l_1_30 = l_1_30 + l_1_29; + l_1_30 = l_1_30 + l_0_18; + l_1_31 = l_1_30 + l_0_8; + l_1_32 = l_1_31 + l_0_8; + l_1_33 = l_1_32 + l_0_8; + l_1_34 = l_1_20 >= 4u; + if l_1_34 { + l_1_35 = f32(input_1_global[l_1_30]); + l_1_36 = f32(input_1_global[l_1_31]); + l_1_37 = f32(input_1_global[l_1_32]); + l_1_38 = f32(input_1_global[l_1_33]); + } else { + l_1_34 = l_1_20 == 3u; + if l_1_34 { + l_1_35 = f32(input_1_global[l_1_30]); + l_1_36 = f32(input_1_global[l_1_31]); + l_1_37 = f32(input_1_global[l_1_32]); + l_1_38 = f32(0u); + } else { + l_1_34 = l_1_20 == 2u; + if l_1_34 { + l_1_35 = f32(input_1_global[l_1_30]); + l_1_36 = f32(input_1_global[l_1_31]); + l_1_37 = f32(0u); + l_1_38 = f32(0u); + } else { + l_1_34 = l_1_20 == 1u; + if l_1_34 { + l_1_35 = f32(input_1_global[l_1_30]); + l_1_36 = f32(0u); + l_1_37 = f32(0u); + l_1_38 = f32(0u); + } + } + } + } + l_4_0 = u32(0u); + l_1_28[l_4_0] = f32(l_1_35); + l_4_0 = l_4_0 + 1u; + l_1_28[l_4_0] = f32(l_1_36); + l_4_0 = l_4_0 + 1u; + l_1_28[l_4_0] = f32(l_1_37); + l_4_0 = l_4_0 + 1u; + l_1_28[l_4_0] = f32(l_1_38); + shared_memory_1[l_1_23] = vec4(l_1_28); + } else { + var l_4_0: u32; + l_1_35 = f32(0u); + l_4_0 = u32(0u); + l_1_28[l_4_0] = f32(l_1_35); + l_4_0 = l_4_0 + 1u; + l_1_28[l_4_0] = f32(l_1_35); + l_4_0 = l_4_0 + 1u; + l_1_28[l_4_0] = f32(l_1_35); + l_4_0 = l_4_0 + 1u; + l_1_28[l_4_0] = f32(l_1_35); + shared_memory_1[l_1_23] = vec4(l_1_28); + } + } + } + workgroupBarrier(); + + for (var l_2_0: u32 = 0u; l_2_0 < 32u; l_2_0++) { + l_1_39 = l_0_13 / 4u; + l_1_39 = l_1_39 * 32u; + l_1_39 = l_1_39 + l_2_0; + l_0_29 = vec4(shared_memory_0[l_1_39]); + l_1_40 = l_2_0 * 64u; + l_1_40 = l_1_40 + l_0_14; + l_1_40 = l_1_40 / 4u; + l_0_30 = vec4(shared_memory_1[l_1_40]); + + for (var l_3_0: u32 = 0u; l_3_0 < 4u; l_3_0++) { + + for (var l_4_0: u32 = 0u; l_4_0 < 4u; l_4_0++) { + l_1_41 = f32(l_0_29[l_3_0]); + l_1_42 = f32(l_0_30[l_4_0]); + l_1_43 = l_1_41 * l_1_42; + l_1_44 = l_3_0 * 4u; + l_1_44 = l_1_44 + l_4_0; + l_1_45 = f32(a_0_0[l_1_44]); + l_1_46 = l_1_45 + l_1_43; + a_0_0[l_1_44] = f32(l_1_46); + } + } + } + workgroupBarrier(); + } + + for (var l_1_0: u32 = 0u; l_1_0 < 4u; l_1_0++) { + + for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) { + l_0_36 = l_0_15 + l_1_0; + l_0_37 = l_0_16 + l_2_0; + l_0_38 = l_0_36 < l_0_2; + l_0_39 = l_0_37 < l_0_4; + l_0_38 = l_0_38 && l_0_39; + if l_0_38 { + var l_3_0: u32; + var l_3_1: f32; + var l_3_2: u32; + l_3_0 = l_1_0 * 4u; + l_3_0 = l_3_0 + l_2_0; + l_3_1 = f32(a_0_0[l_3_0]); + l_0_36 = l_0_36 * l_0_9; + l_0_37 = l_0_37 * l_0_10; + l_3_2 = l_0_36 + l_0_37; + l_3_2 = l_3_2 + l_0_19; + output_0_global[l_3_2] = f32(l_3_1); + } + } + } +} \ No newline at end of file diff --git a/crates/burn-wgpu/src/compute/server.rs b/crates/burn-wgpu/src/compute/server.rs index a03eb74e3..26ce1e195 100644 --- a/crates/burn-wgpu/src/compute/server.rs +++ b/crates/burn-wgpu/src/compute/server.rs @@ -89,6 +89,7 @@ where let compile = kernel.compile(); let pipeline = self.compile_source(&compile.source); + println!("{}", compile.source); self.pipelines.insert(kernel_id.clone(), pipeline.clone());