fix unroll bug

This commit is contained in:
louisfd 2024-06-27 16:59:50 -04:00
parent 71ea9fb415
commit 081fd782af
1 changed file with 2 additions and 4 deletions

View File

@@ -24,7 +24,7 @@ fn tiling2d_cube<F: Float>(
config: Comptime<CubeTiling2dConfig>,
) {
let coordinates = calculate_coordinates(CUBE_POS_X, CUBE_POS_Y, UNIT_POS, config);
let offsets = calculate_batch_offsets::<F>(lhs, rhs, out, CUBE_POS_Z, config);
let offsets = calculate_batch_offsets::<F>(lhs, rhs, out, CUBE_POS_Z);
let shared_memories = make_shared_memories::<F>(config);
tiling2d_core(lhs, rhs, out, coordinates, offsets, shared_memories, config);
}
@@ -90,9 +90,7 @@ fn calculate_batch_offsets<F: Float>(
rhs: &Tensor<F>,
out: &Tensor<F>,
batch_number: UInt,
config: Comptime<CubeTiling2dConfig>,
) -> BatchOffsets {
let unroll = Comptime::map(config, |c| c.unroll);
let rank = out.rank();
let dim_m = lhs.shape(rank - UInt::new(2));
@@ -104,7 +102,7 @@ fn calculate_batch_offsets<F: Float>(
let mut offset_rhs = UInt::new(0);
// Batch offset for lhs, rhs
for b in range(0u32, rank - UInt::new(2), unroll) {
for b in range(0u32, rank - UInt::new(2), Comptime::new(false)) {
let tmp = offset_out / out.stride(b);
offset_lhs += tmp % lhs.shape(b) * lhs.stride(b);
offset_rhs += tmp % rhs.shape(b) * rhs.stride(b);