This commit is contained in:
louisfd 2024-06-29 13:01:39 -04:00
parent 61ca9ff0b6
commit e09d60fff0
5 changed files with 633 additions and 7 deletions

View File

@ -170,9 +170,13 @@ fn write_tile_plain<F: Float>(
tile_size: Comptime<UInt>,
) {
let sm_vectorization = Comptime::runtime(tile_size);
let sm_len = UInt::new(2048) / sm_vectorization;
for i in range(0u32, Comptime::get(tile_size), unroll) {
shared_memory[(sm_position_base + i * sm_stride) / sm_vectorization] = tile[i];
let sm_pos = (sm_position_base + i * sm_stride) / sm_vectorization;
if sm_pos < sm_len {
shared_memory[sm_pos] = tile[i];
}
}
}
@ -185,11 +189,14 @@ fn write_tile_transposed<F: Float>(
unroll: Comptime<bool>,
tile_size: Comptime<UInt>,
) {
let is_scalar = Comptime::map(tile_size, |c| c.val == 1);
let sm_is_scalar = Comptime::map(tile_size, |c| c.val == 1);
let sm_vectorization = Comptime::runtime(tile_size);
let sm_len = UInt::new(2048) / sm_vectorization;
if Comptime::get(is_scalar) {
shared_memory[sm_position_base] = tile[0];
if Comptime::get(sm_is_scalar) {
if sm_position_base < sm_len {
shared_memory[sm_position_base] = tile[0];
}
} else {
for i in range(0u32, Comptime::get(tile_size), unroll) {
let mut transposed = F::vectorized(0., Comptime::get(tile_size));
@ -199,8 +206,10 @@ fn write_tile_transposed<F: Float>(
transposed[j] = tile[j][i];
}
let sm_position = (sm_position_base + i * sm_stride) / sm_vectorization;
shared_memory[sm_position] = transposed;
let sm_pos = (sm_position_base + i * sm_stride) / sm_vectorization;
if sm_pos < sm_len {
shared_memory[sm_pos] = transposed;
}
}
}
}

View File

@ -0,0 +1,236 @@
@group(0)
@binding(0)
var<storage, read_write> input_0_global: array<vec4<f32>>;
@group(0)
@binding(1)
var<storage, read_write> input_1_global: array<vec4<f32>>;
@group(0)
@binding(2)
var<storage, read_write> output_0_global: array<vec4<f32>>;
@group(0)
@binding(3)
var<storage, read_write> info: array<u32>;
var<workgroup> shared_memory_0: array<vec4<f32>, 512>;
var<workgroup> shared_memory_1: array<vec4<f32>, 512>;
const WORKGROUP_SIZE_X = 16u;
const WORKGROUP_SIZE_Y = 16u;
const WORKGROUP_SIZE_Z = 1u;
@compute
@workgroup_size(16, 16, 1)
fn main(
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
var a_0_0: array<f32, 16>;
var a_0_1: array<vec4<f32>, 4>;
var a_0_2: array<vec4<f32>, 4>;
let rank: u32 = info[0];
let rank_2: u32 = rank * 2u;
var l_0_0: u32;
var l_0_1: u32;
var l_0_2: u32;
var l_0_3: u32;
var l_0_4: u32;
var l_0_5: u32;
var l_0_6: u32;
var l_0_7: u32;
var l_0_8: u32;
var l_0_9: u32;
var l_0_10: u32;
var l_0_11: u32;
var l_0_12: u32;
var l_0_13: u32;
var l_0_14: vec4<f32>;
var l_0_15: vec4<f32>;
var l_0_16: f32;
var l_0_17: f32;
l_0_0 = 64u - 1u;
l_0_0 = l_0_0 / 4u;
l_0_0 = l_0_0 + 1u;
l_0_1 = workgroup_id.x * 64u;
l_0_2 = workgroup_id.y * 64u;
l_0_3 = local_idx / l_0_0;
l_0_3 = l_0_3 * 4u;
l_0_0 = local_idx % l_0_0;
l_0_0 = l_0_0 * 4u;
l_0_4 = u32(rank);
l_0_5 = l_0_4 - 2u;
l_0_6 = info[(0u * rank_2) + rank + l_0_5 + 1u];
l_0_5 = l_0_4 - 1u;
l_0_7 = info[(1u * rank_2) + rank + l_0_5 + 1u];
l_0_6 = l_0_6 * l_0_7;
l_0_6 = l_0_6 * workgroup_id.z;
l_0_7 = u32(0u);
l_0_5 = u32(0u);
l_0_4 = l_0_4 - 2u;
for (var l_1_0: u32 = 0u; l_1_0 < l_0_4; l_1_0++) {
l_0_8 = info[(2u * rank_2) + l_1_0 + 1u];
l_0_8 = l_0_6 / l_0_8;
l_0_9 = info[(0u * rank_2) + rank + l_1_0 + 1u];
l_0_9 = l_0_8 % l_0_9;
l_0_10 = info[(0u * rank_2) + l_1_0 + 1u];
l_0_9 = l_0_9 * l_0_10;
l_0_7 = l_0_7 + l_0_9;
l_0_10 = info[(1u * rank_2) + rank + l_1_0 + 1u];
l_0_8 = l_0_8 % l_0_10;
l_0_10 = info[(1u * rank_2) + l_1_0 + 1u];
l_0_8 = l_0_8 * l_0_10;
l_0_5 = l_0_5 + l_0_8;
}
for (var l_1_0: u32 = 0u; l_1_0 < 16u; l_1_0++) {
a_0_0[l_1_0] = f32(0f);
}
l_0_10 = rank - 1u;
l_0_9 = info[(0u * rank_2) + rank + l_0_10 + 1u];
l_0_10 = u32(0u);
l_0_8 = l_0_9 / 32u;
l_0_10 = u32(l_0_8);
for (var l_1_0: u32 = 0u; l_1_0 < l_0_10; l_1_0++) {
l_0_9 = l_1_0 * 32u;
l_0_8 = l_0_0 * 64u;
l_0_8 = l_0_8 + l_0_3;
l_0_4 = rank - 2u;
l_0_11 = info[(0u * rank_2) + l_0_4 + 1u];
l_0_11 = l_0_1 * l_0_11;
l_0_11 = l_0_11 + l_0_9;
l_0_11 = l_0_11 + l_0_7;
l_0_4 = u32(rank);
l_0_4 = l_0_4 - 2u;
l_0_12 = info[(0u * rank_2) + l_0_4 + 1u];
l_0_4 = l_0_3 * l_0_12;
l_0_4 = l_0_4 + l_0_0;
l_0_4 = l_0_4 + l_0_11;
for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) {
l_0_11 = l_2_0 * l_0_12;
l_0_11 = l_0_4 + l_0_11;
l_0_13 = l_0_11 / 4u;
l_0_14 = vec4<f32>(input_0_global[l_0_13]);
a_0_1[l_2_0] = vec4<f32>(l_0_14);
}
for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) {
l_0_14[0u] = f32(0f);
l_0_14[1u] = f32(0f);
l_0_14[2u] = f32(0f);
l_0_14[3u] = f32(0f);
l_0_15 = vec4<f32>(a_0_1[0u]);
l_0_16 = f32(l_0_15[l_2_0]);
l_0_14[0u] = f32(l_0_16);
l_0_15 = vec4<f32>(a_0_1[1u]);
l_0_16 = f32(l_0_15[l_2_0]);
l_0_14[1u] = f32(l_0_16);
l_0_15 = vec4<f32>(a_0_1[2u]);
l_0_16 = f32(l_0_15[l_2_0]);
l_0_14[2u] = f32(l_0_16);
l_0_15 = vec4<f32>(a_0_1[3u]);
l_0_16 = f32(l_0_15[l_2_0]);
l_0_14[3u] = f32(l_0_16);
l_0_13 = l_2_0 * 64u;
l_0_13 = l_0_8 + l_0_13;
l_0_13 = l_0_13 / 4u;
shared_memory_0[l_0_13] = vec4<f32>(l_0_14);
}
l_0_13 = rank - 2u;
l_0_12 = info[(1u * rank_2) + l_0_13 + 1u];
l_0_13 = l_0_3 * 64u;
l_0_13 = l_0_13 + l_0_0;
l_0_12 = l_0_9 * l_0_12;
l_0_12 = l_0_2 + l_0_12;
l_0_12 = l_0_12 + l_0_5;
l_0_11 = u32(rank);
l_0_11 = l_0_11 - 2u;
l_0_8 = info[(1u * rank_2) + l_0_11 + 1u];
l_0_11 = l_0_3 * l_0_8;
l_0_11 = l_0_11 + l_0_0;
l_0_11 = l_0_11 + l_0_12;
for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) {
l_0_12 = l_2_0 * l_0_8;
l_0_12 = l_0_11 + l_0_12;
l_0_4 = l_0_12 / 4u;
l_0_15 = vec4<f32>(input_1_global[l_0_4]);
a_0_2[l_2_0] = vec4<f32>(l_0_15);
}
for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) {
l_0_12 = l_2_0 * 64u;
l_0_12 = l_0_13 + l_0_12;
l_0_12 = l_0_12 / 4u;
l_0_15 = vec4<f32>(a_0_2[l_2_0]);
shared_memory_1[l_0_12] = vec4<f32>(l_0_15);
}
workgroupBarrier();
l_0_13 = u32(l_0_3);
l_0_12 = u32(l_0_0);
for (var l_2_0: u32 = 0u; l_2_0 < 32u; l_2_0++) {
l_0_11 = l_2_0 * 64u;
l_0_11 = l_0_13 + l_0_11;
l_0_11 = l_0_11 / 4u;
l_0_15 = vec4<f32>(shared_memory_0[l_0_11]);
l_0_11 = l_2_0 * 64u;
l_0_11 = l_0_12 + l_0_11;
l_0_11 = l_0_11 / 4u;
l_0_14 = vec4<f32>(shared_memory_1[l_0_11]);
for (var l_3_0: u32 = 0u; l_3_0 < 4u; l_3_0++) {
l_0_11 = l_3_0 * 4u;
for (var l_4_0: u32 = 0u; l_4_0 < 4u; l_4_0++) {
l_0_16 = f32(l_0_15[l_3_0]);
l_0_17 = f32(l_0_14[l_4_0]);
l_0_16 = l_0_16 * l_0_17;
l_0_9 = l_0_11 + l_4_0;
l_0_17 = f32(a_0_0[l_0_9]);
l_0_17 = l_0_17 + l_0_16;
a_0_0[l_0_9] = f32(l_0_17);
}
}
}
workgroupBarrier();
}
l_0_13 = l_0_1 + l_0_3;
l_0_0 = l_0_2 + l_0_0;
l_0_12 = rank - 2u;
l_0_11 = info[(2u * rank_2) + l_0_12 + 1u];
for (var l_1_0: u32 = 0u; l_1_0 < 4u; l_1_0++) {
l_0_12 = l_1_0 * 4u;
l_0_10 = l_0_13 + l_1_0;
l_0_10 = l_0_10 * l_0_11;
l_0_10 = l_0_10 + l_0_0;
l_0_10 = l_0_10 + l_0_6;
for (var l_2_0: u32 = 0u; l_2_0 < 1u; l_2_0++) {
l_0_15[0u] = f32(0f);
l_0_15[1u] = f32(0f);
l_0_15[2u] = f32(0f);
l_0_15[3u] = f32(0f);
for (var l_3_0: u32 = 0u; l_3_0 < 4u; l_3_0++) {
l_0_9 = l_2_0 * 4u;
l_0_9 = l_0_9 + l_3_0;
l_0_9 = l_0_12 + l_0_9;
l_0_17 = f32(a_0_0[l_0_9]);
l_0_15[l_3_0] = f32(l_0_17);
}
l_0_9 = l_0_10 / 4u;
l_0_9 = l_2_0 + l_0_9;
output_0_global[l_0_9] = vec4<f32>(l_0_15);
}
}
}

View File

@ -432,7 +432,7 @@ mod tests {
#[test]
pub fn k_exceeds_block() {
test_with_params(64, 36, 32, 1, 1);
test_with_params(64, 32, 64, 1, 1);
}
#[test]

View File

@ -0,0 +1,380 @@
@group(0)
@binding(0)
var<storage, read_write> input_0_global: array<f32>;
@group(0)
@binding(1)
var<storage, read_write> input_1_global: array<f32>;
@group(0)
@binding(2)
var<storage, read_write> output_0_global: array<f32>;
@group(0)
@binding(3)
var<storage, read_write> info: array<u32>;
var<workgroup> shared_memory_0: array<vec4<f32>, 512>;
var<workgroup> shared_memory_1: array<vec4<f32>, 512>;
const WORKGROUP_SIZE_X = 16u;
const WORKGROUP_SIZE_Y = 16u;
const WORKGROUP_SIZE_Z = 1u;
@compute
@workgroup_size(16, 16, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
var a_0_0: array<f32, 16>;
let rank: u32 = info[0];
let rank_2: u32 = rank * 2u;
var l_0_0: u32;
var l_0_1: u32;
var l_0_2: u32;
var l_0_3: u32;
var l_0_4: u32;
var l_0_5: u32;
var l_0_6: u32;
var l_0_7: u32;
var l_0_8: u32;
var l_0_9: u32;
var l_0_10: u32;
var l_0_11: u32;
var l_0_12: u32;
var l_0_13: u32;
var l_0_14: u32;
var l_0_15: u32;
var l_0_16: u32;
var l_0_17: u32;
var l_0_18: u32;
var l_0_19: u32;
var l_0_20: u32;
var l_0_21: u32;
var l_0_22: u32;
var l_0_23: u32;
var l_0_24: u32;
var l_0_25: u32;
var l_0_26: u32;
var l_0_27: u32;
var l_0_28: u32;
var l_0_29: vec4<f32>;
var l_0_30: vec4<f32>;
var l_0_31: u32;
var l_0_32: u32;
var l_0_33: f32;
var l_0_34: f32;
var l_0_35: f32;
var l_0_36: u32;
var l_0_37: u32;
var l_0_38: bool;
var l_0_39: bool;
l_0_0 = rank - 1u;
l_0_1 = rank - 2u;
l_0_2 = info[(0u * rank_2) + rank + l_0_1 + 1u];
l_0_3 = info[(0u * rank_2) + rank + l_0_0 + 1u];
l_0_4 = info[(1u * rank_2) + rank + l_0_0 + 1u];
l_0_5 = info[(0u * rank_2) + l_0_1 + 1u];
l_0_6 = info[(0u * rank_2) + l_0_0 + 1u];
l_0_7 = info[(1u * rank_2) + l_0_1 + 1u];
l_0_8 = info[(1u * rank_2) + l_0_0 + 1u];
l_0_9 = info[(2u * rank_2) + l_0_1 + 1u];
l_0_10 = info[(2u * rank_2) + l_0_0 + 1u];
l_0_11 = u32(workgroup_id.x);
l_0_11 = l_0_11 * 64u;
l_0_12 = u32(workgroup_id.y);
l_0_12 = l_0_12 * 64u;
l_0_13 = local_idx / 16u;
l_0_13 = l_0_13 * 4u;
l_0_14 = local_idx % 16u;
l_0_14 = l_0_14 * 4u;
l_0_15 = l_0_11 + l_0_13;
l_0_16 = l_0_12 + l_0_14;
l_0_17 = l_0_11 * l_0_5;
l_0_18 = l_0_12 * l_0_8;
l_0_19 = l_0_2 * l_0_4;
l_0_19 = l_0_19 * global_id.z;
l_0_20 = rank - 2u;
for (var l_1_0: u32 = 0u; l_1_0 < l_0_20; l_1_0++) {
l_0_21 = info[(0u * rank_2) + l_1_0 + 1u];
l_0_22 = info[(1u * rank_2) + l_1_0 + 1u];
l_0_23 = info[(2u * rank_2) + l_1_0 + 1u];
l_0_24 = info[(0u * rank_2) + rank + l_1_0 + 1u];
l_0_25 = info[(1u * rank_2) + rank + l_1_0 + 1u];
l_0_26 = l_0_19 / l_0_23;
l_0_27 = l_0_26 % l_0_24;
l_0_27 = l_0_27 * l_0_21;
l_0_17 = l_0_17 + l_0_27;
l_0_28 = l_0_26 % l_0_25;
l_0_28 = l_0_28 * l_0_22;
l_0_18 = l_0_18 + l_0_28;
}
l_0_33 = f32(l_0_3);
l_0_34 = f32(32u);
l_0_35 = l_0_33 / l_0_34;
l_0_35 = ceil(l_0_35);
l_0_31 = u32(l_0_35);
for (var l_1_0: u32 = 0u; l_1_0 < l_0_31; l_1_0++) {
var l_1_1: u32;
var l_1_2: u32;
var l_1_3: bool;
var l_1_4: u32;
var l_1_5: bool;
var l_1_6: u32;
var l_1_7: bool;
var l_1_8: bool;
var l_1_9: vec4<f32>;
var l_1_10: u32;
var l_1_11: u32;
var l_1_12: u32;
var l_1_13: u32;
var l_1_14: u32;
var l_1_15: bool;
var l_1_16: f32;
var l_1_17: f32;
var l_1_18: f32;
var l_1_19: f32;
var l_1_20: u32;
var l_1_21: u32;
var l_1_22: bool;
var l_1_23: u32;
var l_1_24: bool;
var l_1_25: u32;
var l_1_26: bool;
var l_1_27: bool;
var l_1_28: vec4<f32>;
var l_1_29: u32;
var l_1_30: u32;
var l_1_31: u32;
var l_1_32: u32;
var l_1_33: u32;
var l_1_34: bool;
var l_1_35: f32;
var l_1_36: f32;
var l_1_37: f32;
var l_1_38: f32;
var l_1_39: u32;
var l_1_40: u32;
var l_1_41: f32;
var l_1_42: f32;
var l_1_43: f32;
var l_1_44: u32;
var l_1_45: f32;
var l_1_46: f32;
l_0_32 = l_1_0 * 32u;
l_1_1 = l_0_2 - l_0_15;
for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) {
l_1_2 = l_0_14 + l_2_0;
l_1_3 = l_1_2 < 32u;
if l_1_3 {
l_1_4 = l_0_13 / 4u;
l_1_4 = l_1_4 * 32u;
l_1_4 = l_1_4 + l_1_2;
l_1_6 = l_1_2 + l_0_32;
l_1_5 = l_1_6 < l_0_3;
l_1_7 = l_1_1 >= 1u;
l_1_8 = l_1_5 && l_1_7;
if l_1_8 {
var l_4_0: u32;
l_1_11 = l_0_32 + l_1_2;
l_1_11 = l_1_11 * l_0_6;
l_1_10 = l_0_13 * l_0_5;
l_1_11 = l_1_11 + l_1_10;
l_1_11 = l_1_11 + l_0_17;
l_1_12 = l_1_11 + l_0_5;
l_1_13 = l_1_12 + l_0_5;
l_1_14 = l_1_13 + l_0_5;
l_1_15 = l_1_1 >= 4u;
if l_1_15 {
l_1_16 = f32(input_0_global[l_1_11]);
l_1_17 = f32(input_0_global[l_1_12]);
l_1_18 = f32(input_0_global[l_1_13]);
l_1_19 = f32(input_0_global[l_1_14]);
} else {
l_1_15 = l_1_1 == 3u;
if l_1_15 {
l_1_16 = f32(input_0_global[l_1_11]);
l_1_17 = f32(input_0_global[l_1_12]);
l_1_18 = f32(input_0_global[l_1_13]);
l_1_19 = f32(0u);
} else {
l_1_15 = l_1_1 == 2u;
if l_1_15 {
l_1_16 = f32(input_0_global[l_1_11]);
l_1_17 = f32(input_0_global[l_1_12]);
l_1_18 = f32(0u);
l_1_19 = f32(0u);
} else {
l_1_15 = l_1_1 == 1u;
if l_1_15 {
l_1_16 = f32(input_0_global[l_1_11]);
l_1_17 = f32(0u);
l_1_18 = f32(0u);
l_1_19 = f32(0u);
}
}
}
}
l_4_0 = u32(0u);
l_1_9[l_4_0] = f32(l_1_16);
l_4_0 = l_4_0 + 1u;
l_1_9[l_4_0] = f32(l_1_17);
l_4_0 = l_4_0 + 1u;
l_1_9[l_4_0] = f32(l_1_18);
l_4_0 = l_4_0 + 1u;
l_1_9[l_4_0] = f32(l_1_19);
shared_memory_0[l_1_4] = vec4<f32>(l_1_9);
} else {
var l_4_0: u32;
l_1_16 = f32(0u);
l_4_0 = u32(0u);
l_1_9[l_4_0] = f32(l_1_16);
l_4_0 = l_4_0 + 1u;
l_1_9[l_4_0] = f32(l_1_16);
l_4_0 = l_4_0 + 1u;
l_1_9[l_4_0] = f32(l_1_16);
l_4_0 = l_4_0 + 1u;
l_1_9[l_4_0] = f32(l_1_16);
shared_memory_0[l_1_4] = vec4<f32>(l_1_9);
}
}
}
l_1_20 = l_0_4 - l_0_16;
for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) {
l_1_21 = l_0_13 + l_2_0;
l_1_22 = l_1_21 < 32u;
if l_1_22 {
l_1_23 = l_1_21 * 64u;
l_1_23 = l_1_23 + l_0_14;
l_1_23 = l_1_23 / 4u;
l_1_25 = l_1_21 + l_0_32;
l_1_24 = l_1_25 < l_0_3;
l_1_26 = l_1_20 >= 1u;
l_1_27 = l_1_24 && l_1_26;
if l_1_27 {
var l_4_0: u32;
l_1_30 = l_0_32 + l_1_21;
l_1_30 = l_1_30 * l_0_7;
l_1_29 = l_0_14 * l_0_8;
l_1_30 = l_1_30 + l_1_29;
l_1_30 = l_1_30 + l_0_18;
l_1_31 = l_1_30 + l_0_8;
l_1_32 = l_1_31 + l_0_8;
l_1_33 = l_1_32 + l_0_8;
l_1_34 = l_1_20 >= 4u;
if l_1_34 {
l_1_35 = f32(input_1_global[l_1_30]);
l_1_36 = f32(input_1_global[l_1_31]);
l_1_37 = f32(input_1_global[l_1_32]);
l_1_38 = f32(input_1_global[l_1_33]);
} else {
l_1_34 = l_1_20 == 3u;
if l_1_34 {
l_1_35 = f32(input_1_global[l_1_30]);
l_1_36 = f32(input_1_global[l_1_31]);
l_1_37 = f32(input_1_global[l_1_32]);
l_1_38 = f32(0u);
} else {
l_1_34 = l_1_20 == 2u;
if l_1_34 {
l_1_35 = f32(input_1_global[l_1_30]);
l_1_36 = f32(input_1_global[l_1_31]);
l_1_37 = f32(0u);
l_1_38 = f32(0u);
} else {
l_1_34 = l_1_20 == 1u;
if l_1_34 {
l_1_35 = f32(input_1_global[l_1_30]);
l_1_36 = f32(0u);
l_1_37 = f32(0u);
l_1_38 = f32(0u);
}
}
}
}
l_4_0 = u32(0u);
l_1_28[l_4_0] = f32(l_1_35);
l_4_0 = l_4_0 + 1u;
l_1_28[l_4_0] = f32(l_1_36);
l_4_0 = l_4_0 + 1u;
l_1_28[l_4_0] = f32(l_1_37);
l_4_0 = l_4_0 + 1u;
l_1_28[l_4_0] = f32(l_1_38);
shared_memory_1[l_1_23] = vec4<f32>(l_1_28);
} else {
var l_4_0: u32;
l_1_35 = f32(0u);
l_4_0 = u32(0u);
l_1_28[l_4_0] = f32(l_1_35);
l_4_0 = l_4_0 + 1u;
l_1_28[l_4_0] = f32(l_1_35);
l_4_0 = l_4_0 + 1u;
l_1_28[l_4_0] = f32(l_1_35);
l_4_0 = l_4_0 + 1u;
l_1_28[l_4_0] = f32(l_1_35);
shared_memory_1[l_1_23] = vec4<f32>(l_1_28);
}
}
}
workgroupBarrier();
for (var l_2_0: u32 = 0u; l_2_0 < 32u; l_2_0++) {
l_1_39 = l_0_13 / 4u;
l_1_39 = l_1_39 * 32u;
l_1_39 = l_1_39 + l_2_0;
l_0_29 = vec4<f32>(shared_memory_0[l_1_39]);
l_1_40 = l_2_0 * 64u;
l_1_40 = l_1_40 + l_0_14;
l_1_40 = l_1_40 / 4u;
l_0_30 = vec4<f32>(shared_memory_1[l_1_40]);
for (var l_3_0: u32 = 0u; l_3_0 < 4u; l_3_0++) {
for (var l_4_0: u32 = 0u; l_4_0 < 4u; l_4_0++) {
l_1_41 = f32(l_0_29[l_3_0]);
l_1_42 = f32(l_0_30[l_4_0]);
l_1_43 = l_1_41 * l_1_42;
l_1_44 = l_3_0 * 4u;
l_1_44 = l_1_44 + l_4_0;
l_1_45 = f32(a_0_0[l_1_44]);
l_1_46 = l_1_45 + l_1_43;
a_0_0[l_1_44] = f32(l_1_46);
}
}
}
workgroupBarrier();
}
for (var l_1_0: u32 = 0u; l_1_0 < 4u; l_1_0++) {
for (var l_2_0: u32 = 0u; l_2_0 < 4u; l_2_0++) {
l_0_36 = l_0_15 + l_1_0;
l_0_37 = l_0_16 + l_2_0;
l_0_38 = l_0_36 < l_0_2;
l_0_39 = l_0_37 < l_0_4;
l_0_38 = l_0_38 && l_0_39;
if l_0_38 {
var l_3_0: u32;
var l_3_1: f32;
var l_3_2: u32;
l_3_0 = l_1_0 * 4u;
l_3_0 = l_3_0 + l_2_0;
l_3_1 = f32(a_0_0[l_3_0]);
l_0_36 = l_0_36 * l_0_9;
l_0_37 = l_0_37 * l_0_10;
l_3_2 = l_0_36 + l_0_37;
l_3_2 = l_3_2 + l_0_19;
output_0_global[l_3_2] = f32(l_3_1);
}
}
}
}

View File

@ -89,6 +89,7 @@ where
let compile = kernel.compile();
let pipeline = self.compile_source(&compile.source);
println!("{}", compile.source);
self.pipelines.insert(kernel_id.clone(), pipeline.clone());