mirror of https://github.com/tracel-ai/burn.git
Fix lstm batch size bug (#1695)
This commit is contained in:
parent
ce2429eb10
commit
2f294c5092
|
@ -353,4 +353,17 @@ mod tests {
|
||||||
// Asserts that the gradients exist and are non-zero
|
// Asserts that the gradients exist and are non-zero
|
||||||
assert!(*some_gradient.any().into_data().value.first().unwrap());
|
assert!(*some_gradient.any().into_data().value.first().unwrap());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_batched_forward_pass_batch_of_one() {
|
||||||
|
let device = Default::default();
|
||||||
|
let lstm = LstmConfig::new(64, 1024, true).init(&device);
|
||||||
|
let batched_input =
|
||||||
|
Tensor::<TestBackend, 3>::random([1, 2, 64], Distribution::Default, &device);
|
||||||
|
|
||||||
|
let (cell_state, hidden_state) = lstm.forward(batched_input, None);
|
||||||
|
|
||||||
|
assert_eq!(cell_state.shape().dims, [1, 2, 1024]);
|
||||||
|
assert_eq!(hidden_state.shape().dims, [1, 2, 1024]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -161,19 +161,19 @@ impl CompilationSettings {
|
||||||
.zip(info.inputs.iter())
|
.zip(info.inputs.iter())
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.filter_map(|(pos, (desc, input))| {
|
.filter_map(|(pos, (desc, input))| {
|
||||||
let handle = &handles_inputs[pos];
|
|
||||||
|
|
||||||
if !is_contiguous(&handle.strides) {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
match desc.status {
|
match desc.status {
|
||||||
burn_tensor::repr::TensorStatus::ReadOnly => return None,
|
burn_tensor::repr::TensorStatus::ReadOnly => return None,
|
||||||
burn_tensor::repr::TensorStatus::NotInit => return None,
|
burn_tensor::repr::TensorStatus::NotInit => return None,
|
||||||
burn_tensor::repr::TensorStatus::ReadWrite => (),
|
burn_tensor::repr::TensorStatus::ReadWrite => (),
|
||||||
};
|
};
|
||||||
|
|
||||||
Some((pos, desc, input))
|
let handle = &handles_inputs[pos];
|
||||||
|
|
||||||
|
if handle.handle.can_mut() && is_contiguous(&handle.strides) {
|
||||||
|
Some((pos, desc, input))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
})
|
})
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue