mirror of https://github.com/tracel-ai/burn.git
Fix LSTM batch size bug (#1695)
This commit is contained in:
parent
ce2429eb10
commit
2f294c5092
|
@ -353,4 +353,17 @@ mod tests {
|
|||
// Asserts that the gradients exist and are non-zero
|
||||
assert!(*some_gradient.any().into_data().value.first().unwrap());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batched_forward_pass_batch_of_one() {
|
||||
let device = Default::default();
|
||||
let lstm = LstmConfig::new(64, 1024, true).init(&device);
|
||||
let batched_input =
|
||||
Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);
|
||||
|
||||
let (cell_state, hidden_state) = lstm.forward(batched_input, None);
|
||||
|
||||
assert_eq!(cell_state.shape().dims, [1, 2, 1024]);
|
||||
assert_eq!(hidden_state.shape().dims, [1, 2, 1024]);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -161,19 +161,19 @@ impl CompilationSettings {
|
|||
.zip(info.inputs.iter())
|
||||
.enumerate()
|
||||
.filter_map(|(pos, (desc, input))| {
|
||||
let handle = &handles_inputs[pos];
|
||||
|
||||
if !is_contiguous(&handle.strides) {
|
||||
return None;
|
||||
}
|
||||
|
||||
match desc.status {
|
||||
burn_tensor::repr::TensorStatus::ReadOnly => return None,
|
||||
burn_tensor::repr::TensorStatus::NotInit => return None,
|
||||
burn_tensor::repr::TensorStatus::ReadWrite => (),
|
||||
};
|
||||
|
||||
Some((pos, desc, input))
|
||||
let handle = &handles_inputs[pos];
|
||||
|
||||
if handle.handle.can_mut() && is_contiguous(&handle.strides) {
|
||||
Some((pos, desc, input))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
|
|
Loading…
Reference in New Issue