mirror of https://github.com/tracel-ai/burn.git
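Fix unroll bug: never unroll the batch-offset loop in `calculate_batch_offsets`, and drop its now-unused `config` parameter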
This commit is contained in:
parent
71ea9fb415
commit
081fd782af
|
@ -24,7 +24,7 @@ fn tiling2d_cube<F: Float>(
|
|||
config: Comptime<CubeTiling2dConfig>,
|
||||
) {
|
||||
let coordinates = calculate_coordinates(CUBE_POS_X, CUBE_POS_Y, UNIT_POS, config);
|
||||
let offsets = calculate_batch_offsets::<F>(lhs, rhs, out, CUBE_POS_Z, config);
|
||||
let offsets = calculate_batch_offsets::<F>(lhs, rhs, out, CUBE_POS_Z);
|
||||
let shared_memories = make_shared_memories::<F>(config);
|
||||
tiling2d_core(lhs, rhs, out, coordinates, offsets, shared_memories, config);
|
||||
}
|
||||
|
@ -90,9 +90,7 @@ fn calculate_batch_offsets<F: Float>(
|
|||
rhs: &Tensor<F>,
|
||||
out: &Tensor<F>,
|
||||
batch_number: UInt,
|
||||
config: Comptime<CubeTiling2dConfig>,
|
||||
) -> BatchOffsets {
|
||||
let unroll = Comptime::map(config, |c| c.unroll);
|
||||
let rank = out.rank();
|
||||
|
||||
let dim_m = lhs.shape(rank - UInt::new(2));
|
||||
|
@ -104,7 +102,7 @@ fn calculate_batch_offsets<F: Float>(
|
|||
let mut offset_rhs = UInt::new(0);
|
||||
|
||||
// Batch offset for lhs, rhs
|
||||
for b in range(0u32, rank - UInt::new(2), unroll) {
|
||||
for b in range(0u32, rank - UInt::new(2), Comptime::new(false)) {
|
||||
let tmp = offset_out / out.stride(b);
|
||||
offset_lhs += tmp % lhs.shape(b) * lhs.stride(b);
|
||||
offset_rhs += tmp % rhs.shape(b) * rhs.stride(b);
|
||||
|
|
Loading…
Reference in New Issue