mirror of https://github.com/tracel-ai/burn.git
wip
This commit is contained in:
parent
97b53624b9
commit
b2bd0fb313
|
@ -17,27 +17,18 @@ pub(crate) fn compute_loop<F: Float>(
|
||||||
) {
|
) {
|
||||||
let tile_size = Comptime::map(config, |c| c.tile_size);
|
let tile_size = Comptime::map(config, |c| c.tile_size);
|
||||||
let block_size_m = Comptime::map(config, |c| c.block_size_m);
|
let block_size_m = Comptime::map(config, |c| c.block_size_m);
|
||||||
let block_size_k = Comptime::map(config, |c| c.block_size_k);
|
let block_size_k = Comptime::runtime(Comptime::map(config, |c| c.block_size_k));
|
||||||
let block_size_n = Comptime::map(config, |c| c.block_size_n);
|
let block_size_n = Comptime::map(config, |c| c.block_size_n);
|
||||||
// let unroll = Comptime::map(config, |c| c.unroll);
|
let unroll = Comptime::map(config, |c| c.unroll);
|
||||||
|
|
||||||
let unit_row = coordinates.unit_row;
|
let unit_row = coordinates.unit_row;
|
||||||
let unit_col = coordinates.unit_col;
|
let unit_col = coordinates.unit_col;
|
||||||
|
|
||||||
for dot_index in range(0u32, Comptime::runtime(block_size_k), Comptime::new(false)) {
|
for dot_index in range(0u32, block_size_k, unroll) {
|
||||||
let lhs_index =
|
let register_m = shared_lhs[(unit_row + dot_index * Comptime::runtime(block_size_m))
|
||||||
(unit_row + dot_index * Comptime::runtime(block_size_m)) / Comptime::runtime(tile_size);
|
/ Comptime::runtime(tile_size)];
|
||||||
let mut register_m = F::vectorized(0., Comptime::get(tile_size));
|
let register_n = shared_rhs[(unit_col + dot_index * Comptime::runtime(block_size_n))
|
||||||
if lhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) {
|
/ Comptime::runtime(tile_size)];
|
||||||
register_m = shared_lhs[lhs_index];
|
|
||||||
}
|
|
||||||
|
|
||||||
let rhs_index =
|
|
||||||
(unit_col + dot_index * Comptime::runtime(block_size_n)) / Comptime::runtime(tile_size);
|
|
||||||
let mut register_n = F::vectorized(0., Comptime::get(tile_size));
|
|
||||||
if rhs_index < Comptime::runtime(block_size_k * block_size_m / tile_size) {
|
|
||||||
register_n = shared_rhs[rhs_index];
|
|
||||||
}
|
|
||||||
|
|
||||||
tile_outer_product::<F>(register_m, register_n, results, config);
|
tile_outer_product::<F>(register_m, register_n, results, config);
|
||||||
}
|
}
|
||||||
|
@ -138,8 +129,8 @@ pub mod tests {
|
||||||
|
|
||||||
/// Exported test
|
/// Exported test
|
||||||
pub fn compute_loop_unit_offset_test<R: JitRuntime>(device: &R::Device) {
|
pub fn compute_loop_unit_offset_test<R: JitRuntime>(device: &R::Device) {
|
||||||
let lhs = range_tensor_transposed::<R>(8, 4, device);
|
let shared_lhs = range_tensor::<R>(4, 8, device);
|
||||||
let rhs = range_tensor::<R>(4, 8, device);
|
let shared_rhs = range_tensor::<R>(4, 8, device);
|
||||||
let results = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
|
let results = create_empty::<R>(TILE_SIZE, TILE_SIZE, device);
|
||||||
let cube_dim = CubeDim::new(1, 1, 1);
|
let cube_dim = CubeDim::new(1, 1, 1);
|
||||||
let cube_count = CubeCount::new(1, 1, 1);
|
let cube_count = CubeCount::new(1, 1, 1);
|
||||||
|
@ -152,11 +143,19 @@ pub mod tests {
|
||||||
let config = make_config(4, 8, 4);
|
let config = make_config(4, 8, 4);
|
||||||
|
|
||||||
compute_loop_test_launch::<F32, R>(
|
compute_loop_test_launch::<F32, R>(
|
||||||
lhs.client.clone(),
|
shared_lhs.client.clone(),
|
||||||
cube_count,
|
cube_count,
|
||||||
settings,
|
settings,
|
||||||
TensorHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
TensorHandle::new(
|
||||||
TensorHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
|
&shared_lhs.handle,
|
||||||
|
&shared_lhs.strides,
|
||||||
|
&shared_lhs.shape.dims,
|
||||||
|
),
|
||||||
|
TensorHandle::new(
|
||||||
|
&shared_rhs.handle,
|
||||||
|
&shared_rhs.strides,
|
||||||
|
&shared_rhs.shape.dims,
|
||||||
|
),
|
||||||
4,
|
4,
|
||||||
4,
|
4,
|
||||||
ArrayHandle::new(&results, 1),
|
ArrayHandle::new(&results, 1),
|
||||||
|
@ -164,8 +163,8 @@ pub mod tests {
|
||||||
);
|
);
|
||||||
|
|
||||||
let expected = &[
|
let expected = &[
|
||||||
1160.0, 1230.0, 1300.0, 1370.0, 1416.0, 1502.0, 1588.0, 1674.0, 1672.0, 1774.0, 1876.0,
|
1344.0, 1408.0, 1472.0, 1536.0, 1408.0, 1476.0, 1544.0, 1612.0, 1472.0, 1544.0, 1616.0,
|
||||||
1978.0, 1928.0, 2046.0, 2164.0, 2282.0,
|
1688.0, 1536.0, 1612.0, 1688.0, 1764.0,
|
||||||
];
|
];
|
||||||
assert_equals::<R>(results, expected, device);
|
assert_equals::<R>(results, expected, device);
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,10 +96,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
fn compile_source(&self, source: &str) -> Arc<ComputePipeline> {
|
fn compile_source(&self, source: &str) -> Arc<ComputePipeline> {
|
||||||
unsafe {
|
let module = self.device.create_shader_module(ShaderModuleDescriptor {
|
||||||
let module = self
|
|
||||||
.device
|
|
||||||
.create_shader_module_unchecked(ShaderModuleDescriptor {
|
|
||||||
label: None,
|
label: None,
|
||||||
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
|
source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)),
|
||||||
});
|
});
|
||||||
|
@ -115,7 +112,6 @@ where
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
fn buffer_reader(&mut self, handle: server::Binding<Self>) -> BufferReader {
|
fn buffer_reader(&mut self, handle: server::Binding<Self>) -> BufferReader {
|
||||||
let resource = self.memory_management.get(handle.memory);
|
let resource = self.memory_management.get(handle.memory);
|
||||||
|
|
Loading…
Reference in New Issue