mirror of https://github.com/tracel-ai/burn.git
wip
This commit is contained in:
parent
97b53624b9
commit
b2bd0fb313
|
@ -17,27 +17,18 @@ pub(crate) fn compute_loop<F: Float>(
|
|||
) {
|
||||
let tile_size = Comptime::map(config, |c| c.tile_size);
|
||||
let block_size_m = Comptime::map(config, |c| c.block_size_m);
|
||||
let block_size_k = Comptime::map(config, |c| c.block_size_k);
|
||||
let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
|
||||
let block_size_n = Comptime::map(config, |c| c.block_size_n);
|
||||
// let unroll = Comptime::map(config, |c| c.unroll);
|
||||
let unroll = Comptime::map(config, |c| c.unroll);
|
||||
|
||||
let unit_row = coordinates.unit_row;
|
||||
let unit_col = coordinates.unit_col;
|
||||
|
||||
for dot_index in range(0u32, Comptime::runtime(block_size_k), Comptime::new(false)) {
|
||||
let lhs_index =
|
||||
(unit_row + dot_index * Comptime::runtime(block_size_m)) / Comptime::runtime(tile_size);
|
||||
let mut register_m = F::vectorized(0., Comptime::get(tile_size));
|
||||
if lhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) {
|
||||
register_m = shared_lhs[lhs_index];
|
||||
}
|
||||
|
||||
let rhs_index =
|
||||
(unit_col + dot_index * Comptime::runtime(block_size_n)) / Comptime::runtime(tile_size);
|
||||
let mut register_n = F::vectorized(0., Comptime::get(tile_size));
|
||||
if rhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) {
|
||||
register_n = shared_rhs[rhs_index];
|
||||
}
|
||||
for dot_index in range(0u32, block_size_k, unroll) {
|
||||
let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m))
|
||||
/ Comptime::runtime(tile_size)];
|
||||
let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n))
|
||||
/ Comptime::runtime(tile_size)];
|
||||
|
||||
tile_outer_product::<F>(register_m, register_n, results, config);
|
||||
}
|
||||
|
@ -138,8 +129,8 @@ pub mod tests {
|
|||
|
||||
/// Exported test
|
||||
pub fn compute_loop_unit_offset_test<R: JitRuntime>(device: &R::Device) {
|
||||
let lhs = range_tensor_transposed::<R>(8, 4, device);
|
||||
let rhs = range_tensor::<R>(4, 8, device);
|
||||
let shared_lhs = range_tensor::<R>(4, 8, device);
|
||||
let shared_rhs = range_tensor::<R>(4, 8, device);
|
||||
let results = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
|
||||
let cube_dim = CubeDim::new(1, 1, 1);
|
||||
let cube_count = CubeCount::new(1, 1, 1);
|
||||
|
@ -152,11 +143,19 @@ pub mod tests {
|
|||
let config = make_config(4, 8, 4);
|
||||
|
||||
compute_loop_test_launch::<F32, R>(
|
||||
lhs.client.clone(),
|
||||
shared_lhs.client.clone(),
|
||||
cube_count,
|
||||
settings,
|
||||
TensorHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
|
||||
TensorHandle::new(
|
||||
&shared_lhs.handle,
|
||||
&shared_lhs.strides,
|
||||
&shared_lhs.shape.dims,
|
||||
),
|
||||
TensorHandle::new(
|
||||
&shared_rhs.handle,
|
||||
&shared_rhs.strides,
|
||||
&shared_rhs.shape.dims,
|
||||
),
|
||||
4,
|
||||
4,
|
||||
ArrayHandle::new(&results, 1),
|
||||
|
@ -164,8 +163,8 @@ pub mod tests {
|
|||
);
|
||||
|
||||
let expected = &[
|
||||
1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0,
|
||||
1978.0, 1928.0, 2046.0, 2164.0, 2282.0,
|
||||
1344.0, 1408.0, 1472.0, 1536.0, 1408.0, 1476.0, 1544.0, 1612.0, 1472.0, 1544.0, 1616.0,
|
||||
1688.0, 1536.0, 1612.0, 1688.0, 1764.0,
|
||||
];
|
||||
assert_equals::<R>(results, expected, device);
|
||||
}
|
||||
|
|
|
@ -96,25 +96,21 @@ where
|
|||
}
|
||||
|
||||
fn compile_source(&self, source: &str) -> Arc<ComputePipeline> {
|
||||
unsafe {
|
||||
let module = self
|
||||
.device
|
||||
.create_shader_module_unchecked(ShaderModuleDescriptor {
|
||||
label: None,
|
||||
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
|
||||
});
|
||||
let module = self.device.create_shader_module(ShaderModuleDescriptor {
|
||||
label: None,
|
||||
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
|
||||
});
|
||||
|
||||
Arc::new(
|
||||
self.device
|
||||
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: None,
|
||||
module: &module,
|
||||
entry_point: "main",
|
||||
compilation_options: Default::default(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
Arc::new(
|
||||
self.device
|
||||
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: None,
|
||||
module: &module,
|
||||
entry_point: "main",
|
||||
compilation_options: Default::default(),
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn buffer_reader(&mut self, handle: server::Binding<Self>) -> BufferReader {
|
||||
|
|
Loading…
Reference in New Issue