This commit is contained in:
louisfd 2024-06-28 14:19:21 -04:00
parent 97b53624b9
commit b2bd0fb313
2 changed files with 36 additions and 41 deletions

View File

@ -17,27 +17,18 @@ pub(crate) fn compute_loop<F: Float>(
) {
let tile_size = Comptime::map(config, |c| c.tile_size);
let block_size_m = Comptime::map(config, |c| c.block_size_m);
let block_size_k = Comptime::map(config, |c| c.block_size_k);
let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
let block_size_n = Comptime::map(config, |c| c.block_size_n);
// let unroll = Comptime::map(config, |c| c.unroll);
let unroll = Comptime::map(config, |c| c.unroll);
let unit_row = coordinates.unit_row;
let unit_col = coordinates.unit_col;
for dot_index in range(0u32, Comptime::runtime(block_size_k), Comptime::new(false)) {
let lhs_index =
(unit_row + dot_index * Comptime::runtime(block_size_m)) / Comptime::runtime(tile_size);
let mut register_m = F::vectorized(0., Comptime::get(tile_size));
if lhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) {
register_m = shared_lhs[lhs_index];
}
let rhs_index =
(unit_col + dot_index * Comptime::runtime(block_size_n)) / Comptime::runtime(tile_size);
let mut register_n = F::vectorized(0., Comptime::get(tile_size));
if rhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) {
register_n = shared_rhs[rhs_index];
}
for dot_index in range(0u32, block_size_k, unroll) {
let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m))
/ Comptime::runtime(tile_size)];
let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n))
/ Comptime::runtime(tile_size)];
tile_outer_product::<F>(register_m, register_n, results, config);
}
@ -138,8 +129,8 @@ pub mod tests {
/// Exported test
pub fn compute_loop_unit_offset_test<R: JitRuntime>(device: &R::Device) {
let lhs = range_tensor_transposed::<R>(8, 4, device);
let rhs = range_tensor::<R>(4, 8, device);
let shared_lhs = range_tensor::<R>(4, 8, device);
let shared_rhs = range_tensor::<R>(4, 8, device);
let results = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
let cube_dim = CubeDim::new(1, 1, 1);
let cube_count = CubeCount::new(1, 1, 1);
@ -152,11 +143,19 @@ pub mod tests {
let config = make_config(4, 8, 4);
compute_loop_test_launch::<F32, R>(
lhs.client.clone(),
shared_lhs.client.clone(),
cube_count,
settings,
TensorHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
TensorHandle::new(
&shared_lhs.handle,
&shared_lhs.strides,
&shared_lhs.shape.dims,
),
TensorHandle::new(
&shared_rhs.handle,
&shared_rhs.strides,
&shared_rhs.shape.dims,
),
4,
4,
ArrayHandle::new(&results, 1),
@ -164,8 +163,8 @@ pub mod tests {
);
let expected = &[
1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0,
1978.0, 1928.0, 2046.0, 2164.0, 2282.0,
1344.0, 1408.0, 1472.0, 1536.0, 1408.0, 1476.0, 1544.0, 1612.0, 1472.0, 1544.0, 1616.0,
1688.0, 1536.0, 1612.0, 1688.0, 1764.0,
];
assert_equals::<R>(results, expected, device);
}

View File

@ -96,10 +96,7 @@ where
}
fn compile_source(&self, source: &str) -> Arc<ComputePipeline> {
unsafe {
let module = self
.device
.create_shader_module_unchecked(ShaderModuleDescriptor {
let module = self.device.create_shader_module(ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
});
@ -115,7 +112,6 @@ where
}),
)
}
}
fn buffer_reader(&mut self, handle: server::Binding<Self>) -> BufferReader {
let resource = self.memory_management.get(handle.memory);