Feat/lstm (#370)

Mathias Insley 2023-06-06 11:33:22 -07:00 committed by GitHub
parent bff752b1a8
commit 8a88a868ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 519 additions and 2 deletions

View File

@ -37,8 +37,8 @@ pub struct LinearConfig {
/// `U(-k, k)`, where `k = sqrt(1 / d_input)`
#[derive(Module, Debug)]
pub struct Linear<B: Backend> {
- weight: Param<Tensor<B, 2>>,
- bias: Option<Param<Tensor<B, 1>>>,
+ pub(crate) weight: Param<Tensor<B, 2>>,
+ pub(crate) bias: Option<Param<Tensor<B, 1>>>,
}
impl LinearConfig {

View File

@ -0,0 +1,87 @@
use crate as burn;
use crate::module::Module;
use crate::nn::Initializer;
use crate::nn::Linear;
use crate::nn::LinearConfig;
use burn_tensor::backend::Backend;
/// A GateController represents a gate in an LSTM cell. An
/// LSTM cell generally contains three gates: an input gate,
/// a forget gate, and an output gate; the same structure is
/// reused here for the candidate cell ("cell gate") update.
///
/// An LSTM gate is modeled as two linear transformations.
/// The results of these transformations are used to calculate
/// the gate's output.
#[derive(Module, Debug)]
pub struct GateController<B: Backend> {
/// Represents the affine transformation applied to the input vector
pub(crate) input_transform: Linear<B>,
/// Represents the affine transformation applied to the hidden state
pub(crate) hidden_transform: Linear<B>,
}
impl<B: Backend> GateController<B> {
/// Initialize a new [gate_controller](GateController) module.
pub fn new(d_input: usize, d_output: usize, bias: bool, initializer: Initializer) -> Self {
Self {
input_transform: LinearConfig {
d_input,
d_output,
bias,
initializer: initializer.clone(),
}
.init(),
hidden_transform: LinearConfig {
d_input: d_output,
d_output,
bias,
initializer,
}
.init(),
}
}
/// Initialize a new [gate_controller](GateController) module with a [record](GateControllerRecord).
pub fn new_with(linear_config: &LinearConfig, record: GateControllerRecord<B>) -> Self {
let l1 = LinearConfig::init_with(linear_config, record.input_transform);
let l2 = LinearConfig::init_with(linear_config, record.hidden_transform);
Self {
input_transform: l1,
hidden_transform: l2,
}
}
/// Initializes a gate controller with known weight layers,
/// allowing for predictable behavior. Used only for testing
/// the lstm module.
#[cfg(test)]
pub fn create_with_weights(
d_input: usize,
d_output: usize,
bias: bool,
initializer: Initializer,
input_record: crate::nn::LinearRecord<B>,
hidden_record: crate::nn::LinearRecord<B>,
) -> Self {
let l1 = LinearConfig {
d_input,
d_output,
bias,
initializer: initializer.clone(),
}
.init_with(input_record);
let l2 = LinearConfig {
d_input,
d_output,
bias,
initializer,
}
.init_with(hidden_record);
Self {
input_transform: l1,
hidden_transform: l2,
}
}
}
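
Conceptually, a gate's activation is `sigmoid(W_x*X + b_x + W_h*H + b_h)`. The following is a minimal illustrative sketch (not part of this diff) of that computation written against `Linear::forward`; the helper name `gate_activation` is made up, and because the transforms are `pub(crate)` it would only compile inside the crate. The actual `Lstm::forward` below uses the raw weights via `gate_product` instead, and applies `tanh` rather than `sigmoid` for the cell gate.

```rust
use crate::nn::GateController;
use burn_tensor::activation;
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

/// Hypothetical helper: what one gate computes for a batch of inputs.
/// `input` is [batch_size, d_input], `hidden` is [batch_size, d_hidden].
fn gate_activation<B: Backend>(
    gate: &GateController<B>,
    input: Tensor<B, 2>,
    hidden: Tensor<B, 2>,
) -> Tensor<B, 2> {
    // Each Linear applies its weight matrix and optional bias,
    // so this is sigmoid(W_x*X + b_x + W_h*H + b_h).
    activation::sigmoid(gate.input_transform.forward(input) + gate.hidden_transform.forward(hidden))
}
```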

View File

@ -0,0 +1,322 @@
use burn_tensor::activation;
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::nn::lstm::gate_controller;
use crate::nn::Initializer;
use crate::nn::LinearConfig;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use super::gate_controller::GateController;
#[derive(Config)]
pub struct LstmConfig {
/// The size of the input features.
pub d_input: usize,
/// The size of the hidden state.
pub d_hidden: usize,
/// Whether a bias should be applied during the Lstm transformation.
pub bias: bool,
/// Lstm initializer
/// TODO: Make default Xavier initialization, which should be
/// a better choice. https://github.com/burn-rs/burn/issues/371
#[config(default = "Initializer::Uniform(0.0, 1.0)")]
pub initializer: Initializer,
/// The batch size
pub batch_size: usize,
}
/// The Lstm module. This implementation is for a unidirectional, stateless LSTM.
#[derive(Module, Debug)]
pub struct Lstm<B: Backend> {
input_gate: GateController<B>,
forget_gate: GateController<B>,
output_gate: GateController<B>,
cell_gate: GateController<B>,
batch_size: usize,
d_hidden: usize,
}
impl LstmConfig {
/// Initialize a new [lstm](Lstm) module.
pub fn init<B: Backend>(&self) -> Lstm<B> {
let d_output = self.d_hidden;
let input_gate = gate_controller::GateController::new(
self.d_input,
d_output,
self.bias,
self.initializer.clone(),
);
let forget_gate = gate_controller::GateController::new(
self.d_input,
d_output,
self.bias,
self.initializer.clone(),
);
let output_gate = gate_controller::GateController::new(
self.d_input,
d_output,
self.bias,
self.initializer.clone(),
);
let cell_gate = gate_controller::GateController::new(
self.d_input,
d_output,
self.bias,
self.initializer.clone(),
);
Lstm {
input_gate,
forget_gate,
output_gate,
cell_gate,
batch_size: self.batch_size,
d_hidden: self.d_hidden,
}
}
/// Initialize a new [lstm](Lstm) module with a [record](LstmRecord).
pub fn init_with<B: Backend>(&self, record: LstmRecord<B>) -> Lstm<B> {
let linear_config = LinearConfig {
d_input: self.d_input,
d_output: self.d_hidden,
bias: self.bias,
initializer: self.initializer.clone(),
};
Lstm {
input_gate: gate_controller::GateController::new_with(
&linear_config,
record.input_gate,
),
forget_gate: gate_controller::GateController::new_with(
&linear_config,
record.forget_gate,
),
output_gate: gate_controller::GateController::new_with(
&linear_config,
record.output_gate,
),
cell_gate: gate_controller::GateController::new_with(&linear_config, record.cell_gate),
batch_size: self.batch_size,
d_hidden: self.d_hidden,
}
}
}
impl<B: Backend> Lstm<B> {
/// Applies the forward pass on the input tensor. This LSTM implementation
/// returns the cell state and hidden state for each element in a sequence (i.e., across `seq_length`),
/// producing 3-dimensional tensors where the dimensions represent [batch_size, seq_length, hidden_size].
///
/// Parameters:
/// batched_input: The input tensor of shape [batch_size, seq_length, input_size].
/// state: An optional tuple of tensors representing the initial cell state and hidden state.
/// Each state tensor has shape [batch_size, hidden_size].
/// If no initial state is provided, these tensors are initialized to zeros.
///
/// Returns:
/// A tuple of tensors, where the first tensor represents the cell states and
/// the second tensor represents the hidden states for each sequence element.
/// Both output tensors have the shape [batch_size, seq_length, hidden_size].
pub fn forward(
&mut self,
batched_input: Tensor<B, 3>,
state: Option<(Tensor<B, 2>, Tensor<B, 2>)>,
) -> (Tensor<B, 3>, Tensor<B, 3>) {
let seq_length = batched_input.shape().dims[1];
let mut batched_cell_state = Tensor::zeros([self.batch_size, seq_length, self.d_hidden]);
let mut batched_hidden_state = Tensor::zeros([self.batch_size, seq_length, self.d_hidden]);
let (mut cell_state, mut hidden_state) = match state {
Some((cell_state, hidden_state)) => (cell_state, hidden_state),
None => (
Tensor::zeros([self.batch_size, self.d_hidden]),
Tensor::zeros([self.batch_size, self.d_hidden]),
),
};
for t in 0..seq_length {
let indices = Tensor::arange(t..t + 1);
let input_t = batched_input.clone().index_select(1, indices).squeeze(1);
// f(orget)g(ate) tensors
let biased_fg_input_sum = self.gate_product(&input_t, &hidden_state, &self.forget_gate);
let forget_values = activation::sigmoid(biased_fg_input_sum); // to multiply with cell state
// i(nput)g(ate) tensors
let biased_ig_input_sum = self.gate_product(&input_t, &hidden_state, &self.input_gate);
let add_values = activation::sigmoid(biased_ig_input_sum);
// o(utput)g(ate) tensors
let biased_og_input_sum = self.gate_product(&input_t, &hidden_state, &self.output_gate);
let output_values = activation::sigmoid(biased_og_input_sum);
// c(ell)g(ate) tensors
let biased_cg_input_sum = self.gate_product(&input_t, &hidden_state, &self.cell_gate);
let candidate_cell_values = biased_cg_input_sum.tanh();
cell_state = forget_values * cell_state.clone() + add_values * candidate_cell_values;
hidden_state = output_values * cell_state.clone().tanh();
// store the state for this timestep
batched_cell_state = batched_cell_state.index_assign(
[0..self.batch_size, t..(t + 1), 0..self.d_hidden],
cell_state.clone().unsqueeze(),
);
batched_hidden_state = batched_hidden_state.index_assign(
[0..self.batch_size, t..(t + 1), 0..self.d_hidden],
hidden_state.clone().unsqueeze(),
);
}
(batched_cell_state, batched_hidden_state)
}
/// Helper function that performs the weighted matrix products for a gate and
/// adds the biases, if any.
///
/// Mathematically, performs `Wx*X + Wh*H + b`, where:
/// Wx = weight matrix for the connection to input vector X
/// Wh = weight matrix for the connection to hidden state H
/// X = input vector
/// H = hidden state
/// b = bias terms
fn gate_product(
&self,
input: &Tensor<B, 2>,
hidden: &Tensor<B, 2>,
gate: &GateController<B>,
) -> Tensor<B, 2> {
let input_product = input.clone().matmul(gate.input_transform.weight.val());
let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());
let input_bias = gate
.input_transform
.bias
.as_ref()
.map(|bias_param| bias_param.val());
let hidden_bias = gate
.hidden_transform
.bias
.as_ref()
.map(|bias_param| bias_param.val());
match (input_bias, hidden_bias) {
(Some(input_bias), Some(hidden_bias)) => {
input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
}
(Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product,
(None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(),
(None, None) => input_product + hidden_product,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{module::Param, nn::LinearRecord, TestBackend};
use burn_tensor::Data;
#[test]
fn initializer_default() {
TestBackend::seed(0);
let config = LstmConfig::new(5, 5, false, 2);
let lstm = config.init::<TestBackend>();
lstm.input_gate
.input_transform
.weight
.val()
.to_data()
.assert_in_range(0.0, 1.0);
lstm.forget_gate
.input_transform
.weight
.val()
.to_data()
.assert_in_range(0.0, 1.0);
lstm.output_gate
.input_transform
.weight
.val()
.to_data()
.assert_in_range(0.0, 1.0);
lstm.cell_gate
.input_transform
.weight
.val()
.to_data()
.assert_in_range(0.0, 1.0);
}
/// Test forward pass with simple input vector.
///
/// The same record is used for the input and hidden transforms of each gate,
/// and the initial hidden and cell states are zero, so:
///
/// f_t = sigmoid(0.7*0.1 + 0.7*0) = sigmoid(0.07) = 0.5175
/// i_t = sigmoid(0.5*0.1 + 0.5*0) = sigmoid(0.05) = 0.5125
/// o_t = sigmoid(1.1*0.1 + 1.1*0) = sigmoid(0.11) = 0.5275
/// c_t = tanh(0.9*0.1 + 0.9*0) = tanh(0.09) = 0.0898
/// C_t = f_t * 0 + i_t * c_t = 0.5125 * 0.0898 = 0.0460
/// h_t = o_t * tanh(C_t) = 0.5275 * tanh(0.0460) = 0.0242
#[test]
fn test_forward_single_input_single_feature() {
TestBackend::seed(0);
let config = LstmConfig::new(1, 1, false, 1);
let mut lstm = config.init::<TestBackend>();
fn create_gate_controller(
weights: f32,
biases: f32,
d_input: usize,
d_output: usize,
bias: bool,
initializer: Initializer,
) -> GateController<TestBackend> {
let record = LinearRecord {
weight: Param::from(Tensor::from_data(Data::from([[weights]]))),
bias: Some(Param::from(Tensor::from_data(Data::from([biases])))),
};
gate_controller::GateController::create_with_weights(
d_input,
d_output,
bias,
initializer,
record.clone(),
record,
)
}
lstm.input_gate =
create_gate_controller(0.5, 0.0, 1, 1, false, Initializer::UniformDefault);
lstm.forget_gate =
create_gate_controller(0.7, 0.0, 1, 1, false, Initializer::UniformDefault);
lstm.cell_gate = create_gate_controller(0.9, 0.0, 1, 1, false, Initializer::UniformDefault);
lstm.output_gate =
create_gate_controller(1.1, 0.0, 1, 1, false, Initializer::UniformDefault);
// single timestep with single feature
let input = Tensor::<TestBackend, 3>::from_data(Data::from([[[0.1]]]));
let (cell_state_batch, hidden_state_batch) = lstm.forward(input, None);
let cell_state = cell_state_batch
.index_select(0, Tensor::arange(0..1))
.squeeze(0);
let hidden_state = hidden_state_batch
.index_select(0, Tensor::arange(0..1))
.squeeze(0);
cell_state
.to_data()
.assert_approx_eq(&Data::from([[0.046]]), 3);
hidden_state
.to_data()
.assert_approx_eq(&Data::from([[0.024]]), 3)
}
}
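
For context, a minimal end-to-end usage sketch of the new module (not part of this diff). It assumes the public re-export path `burn::nn::{Lstm, LstmConfig}` introduced by the `mod.rs` changes below and some backend `B`; the helper name `run_lstm` and the config values are illustrative. Shapes follow the `forward` documentation above, and the input's batch dimension must match the configured `batch_size`.

```rust
use burn::nn::{Lstm, LstmConfig};
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

fn run_lstm<B: Backend>(batched_input: Tensor<B, 3>) -> (Tensor<B, 3>, Tensor<B, 3>) {
    // d_input = 4, d_hidden = 8, no bias, batch_size = 2 (illustrative values).
    let mut lstm: Lstm<B> = LstmConfig::new(4, 8, false, 2).init();
    // batched_input: [batch_size = 2, seq_length, d_input = 4].
    // No initial state is given, so the cell and hidden states start at zeros.
    // Returns ([2, seq_length, 8] cell states, [2, seq_length, 8] hidden states).
    lstm.forward(batched_input, None)
}
```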

View File

@ -0,0 +1,6 @@
mod gate_controller;
#[allow(clippy::module_inception)]
pub mod lstm;
pub use gate_controller::*;
pub use lstm::*;

View File

@ -10,6 +10,7 @@ mod embedding;
mod gelu;
mod initializer;
mod linear;
mod lstm;
mod norm;
mod relu;
@ -18,5 +19,6 @@ pub use embedding::*;
pub use gelu::*;
pub use initializer::*;
pub use linear::*;
pub use lstm::*;
pub use norm::*;
pub use relu::*;

View File

@ -110,6 +110,50 @@ where
Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
}
/// Squeeze the tensor along the given dimension, removing the specified dimension
/// of size one, and effectively reducing the rank of the tensor by one.
///
/// # Arguments
///
/// - `dim`: The dimension to be squeezed.
///
/// # Type Parameters
///
/// - `D2`: The resulting number of dimensions in the squeezed tensor.
///
/// # Returns
///
/// A new `Tensor<B, D2, K>` instance with the specified dimension removed.
///
/// # Example
///
/// ```rust
///
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Shape};
///
/// fn example<B: Backend>() {
/// let tensor = Tensor::<B, 3>::ones(Shape::new([2, 1, 4]));
///
/// // Given a 3D tensor with dimensions (2, 1, 4), squeeze the dimension 1
/// let squeezed_tensor: Tensor::<B, 2> = tensor.squeeze(1);
///
/// // Resulting tensor will have dimensions (2, 4)
/// println!("{:?}", squeezed_tensor.shape());
/// }
/// ```
pub fn squeeze<const D2: usize>(self, dim: usize) -> Tensor<B, D2, K> {
check!(TensorCheck::squeeze::<D2>(dim, &self.shape().dims));
let current_dims = self.shape().dims;
let mut new_dims: [usize; D2] = [0; D2];
new_dims[..dim].copy_from_slice(&current_dims[..dim]);
new_dims[dim..].copy_from_slice(&current_dims[dim + 1..]);
Tensor::new(K::reshape::<D, D2>(self.primitive, new_dims.into()))
}
/// Unsqueeze the current tensor. Create new dimensions to fit the given size.
///
/// # Panics

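As a usage note, the combination used in `Lstm::forward` above, selecting a single timestep and then dropping the length-1 dimension, looks roughly like this sketch (the helper name `timestep` and the shape comments are illustrative assumptions):

```rust
use burn_tensor::backend::Backend;
use burn_tensor::Tensor;

/// Hypothetical helper: extract timestep `t` from a batched sequence.
fn timestep<B: Backend>(batched: Tensor<B, 3>, t: usize) -> Tensor<B, 2> {
    // [batch, seq, features] -> index_select -> [batch, 1, features] -> squeeze -> [batch, features]
    batched.index_select(1, Tensor::arange(t..t + 1)).squeeze(1)
}
```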
View File

@ -133,6 +133,23 @@ impl TensorCheck {
check
}
pub(crate) fn squeeze<const D2: usize>(dim: usize, tensor_dims: &[usize]) -> Self {
let mut check = Self::Ok;
// The dimension to squeeze must have a size of 1
if tensor_dims[dim] != 1 {
check = check.register(
"Squeeze",
TensorError::new(format!(
"Can't squeeze dimension {} because its size is not 1",
dim
)),
);
}
check
}
pub(crate) fn unsqueeze<const D1: usize, const D2: usize>() -> Self {
let mut check = Self::Ok;
if D2 < D1 {

View File

@ -47,6 +47,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_reshape!();
burn_tensor::testgen_flatten!();
burn_tensor::testgen_sin!();
burn_tensor::testgen_squeeze!();
burn_tensor::testgen_tanh!();
burn_tensor::testgen_sub!();
burn_tensor::testgen_transpose!();

View File

@ -22,6 +22,7 @@ mod repeat;
mod reshape;
mod sin;
mod sqrt;
mod squeeze;
mod sub;
mod tanh;
mod transpose;

View File

@ -0,0 +1,37 @@
#[burn_tensor_testgen::testgen(squeeze)]
mod tests {
use super::*;
use burn_tensor::{Data, Shape, Tensor};
/// Test if the function can successfully squeeze the size 1 dimension of a 3D tensor.
#[test]
fn should_squeeze() {
let tensor = Tensor::<TestBackend, 3>::ones(Shape::new([2, 1, 4]));
let squeezed_tensor: Tensor<TestBackend, 2> = tensor.squeeze(1);
let expected_shape = Shape::new([2, 4]);
assert_eq!(squeezed_tensor.shape(), expected_shape);
}
/// Test if the function can successfully squeeze the first size 1 dimension of a 4D tensor.
#[test]
fn should_squeeze_first() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([1, 3, 4, 5]));
let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(0);
let expected_shape = Shape::new([3, 4, 5]);
assert_eq!(squeezed_tensor.shape(), expected_shape);
}
/// Test if the function can successfully squeeze the last size 1 dimension of a 4D tensor.
#[test]
fn should_squeeze_last() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 1]));
let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(3);
let expected_shape = Shape::new([2, 3, 4]);
assert_eq!(squeezed_tensor.shape(), expected_shape);
}
/// Test if the function panics when the squeezed dimension is not of size 1.
#[test]
#[should_panic]
fn should_squeeze_panic() {
let tensor = Tensor::<TestBackend, 4>::ones(Shape::new([2, 3, 4, 5]));
let squeezed_tensor: Tensor<TestBackend, 3> = tensor.squeeze(2);
}
}