Bug/lstm unsqueeze (#873)

This commit is contained in:
Mathias Insley 2023-10-18 15:53:37 -07:00 committed by GitHub
parent dd4e72a98f
commit 1962c06c21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 5 deletions

View File

@ -140,9 +140,12 @@ impl<B: Backend> Gru<B> {
.mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) // (1 - z(t)) = -(z(t) - 1)
+ update_values.clone().mul(hidden_t);
let current_shape = state_vector.shape().dims;
let unsqueezed_shape = [current_shape[0], 1, current_shape[1]];
let reshaped_state_vector = state_vector.reshape(unsqueezed_shape);
hidden_state = hidden_state.slice_assign(
[0..batch_size, t..(t + 1), 0..self.d_hidden],
state_vector.clone().unsqueeze(),
reshaped_state_vector,
);
}
@ -193,7 +196,7 @@ impl<B: Backend> Gru<B> {
mod tests {
use super::*;
use crate::{module::Param, nn::LinearRecord, TestBackend};
use burn_tensor::Data;
use burn_tensor::{Data, Distribution};
/// Test forward pass with simple input vector.
///
@ -263,4 +266,14 @@ mod tests {
output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3);
}
#[test]
fn test_batched_forward_pass() {
let gru = GruConfig::new(64, 1024, true).init::<TestBackend>();
let batched_input = Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default);
let hidden_state = gru.forward(batched_input, None);
assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
}
}

View File

@ -155,14 +155,19 @@ impl<B: Backend> Lstm<B> {
cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values;
hidden_state = output_values * cell_state.clone().tanh();
let unsqueezed_shape = [cell_state.shape().dims[0], 1, cell_state.shape().dims[1]];
let unsqueezed_cell_state = cell_state.clone().reshape(unsqueezed_shape);
let unsqueezed_hidden_state = hidden_state.clone().reshape(unsqueezed_shape);
// store the state for this timestep
batched_cell_state = batched_cell_state.slice_assign(
[0..batch_size, t..(t + 1), 0..self.d_hidden],
cell_state.clone().unsqueeze(),
unsqueezed_cell_state.clone(),
);
batched_hidden_state = batched_hidden_state.slice_assign(
[0..batch_size, t..(t + 1), 0..self.d_hidden],
hidden_state.clone().unsqueeze(),
unsqueezed_hidden_state.clone(),
);
}
@ -213,7 +218,7 @@ impl<B: Backend> Lstm<B> {
mod tests {
use super::*;
use crate::{module::Param, nn::LinearRecord, TestBackend};
use burn_tensor::Data;
use burn_tensor::{Data, Distribution};
#[test]
fn test_with_uniform_initializer() {
@ -317,4 +322,15 @@ mod tests {
.to_data()
.assert_approx_eq(&Data::from([[0.024]]), 3)
}
#[test]
fn test_batched_forward_pass() {
let lstm = LstmConfig::new(64, 1024, true).init::<TestBackend>();
let batched_input = Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default);
let (cell_state, hidden_state) = lstm.forward(batched_input, None);
assert_eq!(cell_state.shape().dims, [8, 10, 1024]);
assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
}
}