mirror of https://github.com/tracel-ai/burn.git
Bug/lstm unsqueeze (#873)
This commit is contained in:
parent
dd4e72a98f
commit
1962c06c21
|
@ -140,9 +140,12 @@ impl<B: Backend> Gru<B> {
|
|||
.mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1)
|
||||
+ update_values.clone().mul(hidden_t);
|
||||
|
||||
let current_shape = state_vector.shape().dims;
|
||||
let unsqueezed_shape = [current_shape[0], 1, current_shape[1]];
|
||||
let reshaped_state_vector = state_vector.reshape(unsqueezed_shape);
|
||||
hidden_state = hidden_state.slice_assign(
|
||||
[0..batch_size, t..(t + 1), 0..self.d_hidden],
|
||||
state_vector.clone().unsqueeze(),
|
||||
reshaped_state_vector,
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -193,7 +196,7 @@ impl<B: Backend> Gru<B> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::{module::Param, nn::LinearRecord, TestBackend};
|
||||
use burn_tensor::Data;
|
||||
use burn_tensor::{Data, Distribution};
|
||||
|
||||
/// Test forward pass with simple input vector.
|
||||
///
|
||||
|
@ -263,4 +266,14 @@ mod tests {
|
|||
|
||||
output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3);
|
||||
}
|
||||
|
||||
/// Checks that a GRU forward pass over a batched input produces an output
/// tensor of the expected shape.
///
/// The input is a random rank-3 tensor of shape `[8, 10, 64]`; the tensor
/// returned by `forward` must have shape `[8, 10, 1024]`.
#[test]
|
||||
fn test_batched_forward_pass() {
|
||||
// NOTE(review): the arguments look like (input size, hidden size, bias flag),
// judging by the 64 -> 1024 change in the last dimension — confirm against GruConfig.
let gru = GruConfig::new(64, 1024, true).init::<TestBackend>();
|
||||
// presumably laid out as (batch, sequence, features) — TODO confirm
let batched_input = Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default);
|
||||
|
||||
// `None` is passed as the second argument; presumably this means no initial
// state is supplied — verify against `Gru::forward`.
let hidden_state = gru.forward(batched_input, None);
|
||||
|
||||
// Only the shape is asserted; the input is random, so values are not checked.
assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -155,14 +155,19 @@ impl<B: Backend> Lstm<B> {
|
|||
cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values;
|
||||
hidden_state = output_values * cell_state.clone().tanh();
|
||||
|
||||
let unsqueezed_shape = [cell_state.shape().dims[0], 1, cell_state.shape().dims[1]];
|
||||
|
||||
let unsqueezed_cell_state = cell_state.clone().reshape(unsqueezed_shape);
|
||||
let unsqueezed_hidden_state = hidden_state.clone().reshape(unsqueezed_shape);
|
||||
|
||||
// store the state for this timestep
|
||||
batched_cell_state = batched_cell_state.slice_assign(
|
||||
[0..batch_size, t..(t + 1), 0..self.d_hidden],
|
||||
cell_state.clone().unsqueeze(),
|
||||
unsqueezed_cell_state.clone(),
|
||||
);
|
||||
batched_hidden_state = batched_hidden_state.slice_assign(
|
||||
[0..batch_size, t..(t + 1), 0..self.d_hidden],
|
||||
hidden_state.clone().unsqueeze(),
|
||||
unsqueezed_hidden_state.clone(),
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -213,7 +218,7 @@ impl<B: Backend> Lstm<B> {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::{module::Param, nn::LinearRecord, TestBackend};
|
||||
use burn_tensor::Data;
|
||||
use burn_tensor::{Data, Distribution};
|
||||
|
||||
#[test]
|
||||
fn test_with_uniform_initializer() {
|
||||
|
@ -317,4 +322,15 @@ mod tests {
|
|||
.to_data()
|
||||
.assert_approx_eq(&Data::from([[0.024]]), 3)
|
||||
}
|
||||
|
||||
/// Checks that an LSTM forward pass over a batched input produces cell-state
/// and hidden-state tensors of the expected shape.
///
/// The input is a random rank-3 tensor of shape `[8, 10, 64]`; both tensors
/// returned by `forward` must have shape `[8, 10, 1024]`.
#[test]
|
||||
fn test_batched_forward_pass() {
|
||||
// NOTE(review): the arguments look like (input size, hidden size, bias flag),
// judging by the 64 -> 1024 change in the last dimension — confirm against LstmConfig.
let lstm = LstmConfig::new(64, 1024, true).init::<TestBackend>();
|
||||
// presumably laid out as (batch, sequence, features) — TODO confirm
let batched_input = Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default);
|
||||
|
||||
// `None` is passed as the second argument; presumably this means no initial
// state is supplied — verify against `Lstm::forward`.
let (cell_state, hidden_state) = lstm.forward(batched_input, None);
|
||||
|
||||
// Only the shapes are asserted; the input is random, so values are not checked.
assert_eq!(cell_state.shape().dims, [8, 10, 1024]);
|
||||
assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue