Fix lstm batch size bug (#1695)

This commit is contained in:
Nathaniel Simard 2024-04-26 08:54:12 -04:00 committed by GitHub
parent ce2429eb10
commit 2f294c5092
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 7 deletions

View File

@@ -353,4 +353,17 @@ mod tests {
// Asserts that the gradients exist and are non-zero
assert!(*some_gradient.any().into_data().value.first().unwrap());
}
#[test]
// Regression test for the LSTM batch-size bug (#1695): a forward pass with a
// batch dimension of exactly 1 must succeed and keep the batch dimension in
// the output shapes.
fn test_batched_forward_pass_batch_of_one() {
    let device = Default::default();
    // LSTM with input feature size 64, hidden size 1024, bias enabled.
    let lstm = LstmConfig::new(64, 1024, true).init(&device);
    // Input shape is [batch=1, seq_len=2, features=64] — the batch-of-one case
    // that previously failed.
    let batched_input =
        Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);
    let (cell_state, hidden_state) = lstm.forward(batched_input, None);
    // Both states must preserve batch and sequence dims: [1, 2, hidden=1024].
    assert_eq!(cell_state.shape().dims, [1, 2, 1024]);
    assert_eq!(hidden_state.shape().dims, [1, 2, 1024]);
}
}

View File

@@ -161,19 +161,19 @@ impl CompilationSettings {
.zip(info.inputs.iter())
.enumerate()
.filter_map(|(pos, (desc, input))| {
let handle = &handles_inputs[pos];
if !is_contiguous(&handle.strides) {
return None;
}
match desc.status {
burn_tensor::repr::TensorStatus::ReadOnly => return None,
burn_tensor::repr::TensorStatus::NotInit => return None,
burn_tensor::repr::TensorStatus::ReadWrite => (),
};
Some((pos, desc, input))
let handle = &handles_inputs[pos];
if handle.handle.can_mut() && is_contiguous(&handle.strides) {
Some((pos, desc, input))
} else {
None
}
})
.collect::<Vec<_>>();